fix: chat bugs (#3308)

This commit is contained in:
shaohuzhang1 2025-06-19 14:53:24 +08:00 committed by GitHub
parent 03ec0f3fdf
commit 598b72fd12
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 83 additions and 68 deletions

View File

@ -75,7 +75,7 @@ class IChatStep(IBaseChatPipelineStep):
no_references_setting = NoReferencesSetting(required=True,
label=_("No reference segment settings"))
user_id = serializers.UUIDField(required=True, label=_("User ID"))
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
model_setting = serializers.DictField(required=True, allow_null=True,
label=_("Model settings"))
@ -102,7 +102,7 @@ class IChatStep(IBaseChatPipelineStep):
chat_id, problem_text,
post_response_handler: PostResponseHandler,
model_id: str = None,
user_id: str = None,
workspace_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,

View File

@ -26,7 +26,7 @@ from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
from application.models import ApplicationChatUserStats, ChatUserType
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
@ -157,7 +157,7 @@ class BaseChatStep(IChatStep):
problem_text,
post_response_handler: PostResponseHandler,
model_id: str = None,
user_id: str = None,
workspace_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
@ -167,8 +167,8 @@ class BaseChatStep(IChatStep):
model_params_setting=None,
model_setting=None,
**kwargs):
chat_model = get_model_instance_by_model_user_id(model_id, user_id,
**model_params_setting) if model_id is not None else None
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting) if model_id is not None else None
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,

View File

@ -27,7 +27,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
label=_("History Questions"))
# 大语言模型
model_id = serializers.UUIDField(required=False, allow_null=True, label=_("Model id"))
user_id = serializers.UUIDField(required=True, label=_("User ID"))
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
problem_optimization_prompt = serializers.CharField(required=False, max_length=102400,
label=_("Question completion prompt"))
@ -50,6 +50,6 @@ class IResetProblemStep(IBaseChatPipelineStep):
@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
workspace_id=None,
**kwargs):
pass

View File

@ -14,7 +14,7 @@ from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord
from common.utils.split_model import flat_map
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
prompt = _(
"() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the <data></data> tag")
@ -23,9 +23,9 @@ prompt = _(
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None,
problem_optimization_prompt=None,
user_id=None,
workspace_id=None,
**kwargs) -> str:
chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id) if model_id is not None else None
if chat_model is None:
return problem_text
start_index = len(history_chat_record) - 3

View File

@ -44,7 +44,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message=_("The type only supports embedding|keywords|blend"), code=500)
], label=_("Retrieval Mode"))
user_id = serializers.UUIDField(required=True, label=_("User ID"))
workspace_id = serializers.CharField(required=True, label=_("Workspace ID"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
@ -58,19 +58,19 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
workspace_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
:param similarity: 相关性
:param top_n: 查询多少条
:param problem_text: 用户问题
:param knowledge_id_list: 需要查询的数据集id列表
:param knowledge_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:param user_id 用户id
:param workspace_id 工作空间id
:return: 段落列表
"""
pass

View File

@ -25,13 +25,13 @@ from models_provider.models import Model
from models_provider.tools import get_model
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
def get_model_by_id(_id, workspace_id):
model = QuerySet(Model).filter(id=_id, model_type="EMBEDDING").first()
if model is None:
raise Exception(_("Model does not exist"))
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
message = lazy_format(_('No permission to use this model {model_name}'), model_name=model.name)
raise Exception(message)
if model.workspace_id is not None:
if model.workspace_id != workspace_id:
raise Exception(_("Model does not exist"))
return model
@ -50,13 +50,13 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
workspace_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(knowledge_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
model_id = get_embedding_id(knowledge_id_list)
model = get_model_by_id(model_id, user_id)
model = get_model_by_id(model_id, workspace_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)

View File

@ -11,7 +11,6 @@ import json
import re
import time
from functools import reduce
from types import AsyncGeneratorType
from typing import List, Dict
from django.db.models import QuerySet
@ -24,7 +23,7 @@ from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
from application.flow.tools import Reasoning
from models_provider.models import Model
from models_provider.tools import get_model_credential, get_model_instance_by_model_user_id
from models_provider.tools import get_model_credential, get_model_instance_by_model_workspace_id
tool_message_template = """
<details>
@ -206,8 +205,9 @@ class BaseChatNode(IChatNode):
model_setting = {'reasoning_content_enable': False, 'reasoning_content_end': '</think>',
'reasoning_content_start': '<think>'}
self.context['model_setting'] = model_setting
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
self.runtime_node_id)
self.context['history_message'] = history_message

View File

@ -9,7 +9,7 @@ from application.flow.i_step_node import NodeResult
from application.flow.step_node.image_generate_step_node.i_image_generate_node import IImageGenerateNode
from common.utils.common import bytes_to_uploaded_file
from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
class BaseImageGenerateNode(IImageGenerateNode):
@ -25,8 +25,9 @@ class BaseImageGenerateNode(IImageGenerateNode):
**kwargs) -> NodeResult:
print(model_params_setting)
application = self.workflow_manage.work_flow_post_handler.chat_info.application
tti_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
tti_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from knowledge.models import File
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@ -79,9 +79,9 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
# 处理不正确的参数
if image is None or not isinstance(image, list):
image = []
print(model_params_setting)
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
image_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
# 执行详情中的历史消息不需要图片内容
history_message = self.get_history_message_for_details(history_chat_record, dialogue_number)
self.context['history_message'] = history_message

View File

@ -18,7 +18,7 @@ from langchain_core.messages import BaseMessage
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.question_node.i_question_node import IQuestionNode
from models_provider.models import Model
from models_provider.tools import get_model_instance_by_model_user_id, get_model_credential
from models_provider.tools import get_model_instance_by_model_workspace_id, get_model_credential
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
@ -87,8 +87,9 @@ class BaseQuestionNode(IQuestionNode):
**kwargs) -> NodeResult:
if model_params_setting is None:
model_params_setting = get_default_model_params_setting(model_id)
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
chat_model = get_model_instance_by_model_workspace_id(model_id, workspace_id,
**model_params_setting)
history_message = self.get_history_message(history_chat_record, dialogue_number)
self.context['history_message'] = history_message
question = self.generate_prompt_question(prompt)

View File

@ -12,7 +12,7 @@ from langchain_core.documents import Document
from application.flow.i_step_node import NodeResult
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
def merge_reranker_list(reranker_list, result=None):
@ -78,8 +78,9 @@ class BaseRerankerNode(IRerankerNode):
self.context['document_list'] = [{'page_content': document.page_content, 'metadata': document.metadata} for
document in documents]
self.context['question'] = question
reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
self.flow_params_serializer.data.get('user_id'),
workspace_id = self.workflow_manage.get_body().get('workspace_id')
reranker_model = get_model_instance_by_model_workspace_id(reranker_model_id,
workspace_id,
top_n=top_n)
result = reranker_model.compress_documents(
documents,

View File

@ -19,7 +19,7 @@ from common.db.search import native_search
from common.utils.common import get_file_content
from knowledge.models import Document, Paragraph, Knowledge, SearchMode
from maxkb.conf import PROJECT_DIR
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
def get_embedding_id(dataset_id_list):
@ -67,7 +67,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
workspace_id = self.workflow_manage.get_body().get('workspace_id')
embedding_model = get_model_instance_by_model_workspace_id(model_id, workspace_id)
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in

View File

@ -9,7 +9,7 @@ from application.flow.i_step_node import NodeResult
from application.flow.step_node.speech_to_text_step_node.i_speech_to_text_node import ISpeechToTextNode
from common.utils.common import split_and_transcribe, any_to_mp3
from knowledge.models import File
from models_provider.tools import get_model_instance_by_model_user_id
from models_provider.tools import get_model_instance_by_model_workspace_id
class BaseSpeechToTextNode(ISpeechToTextNode):
@ -20,7 +20,8 @@ class BaseSpeechToTextNode(ISpeechToTextNode):
self.answer_text = details.get('answer')
def execute(self, stt_model_id, chat_id, audio, **kwargs) -> NodeResult:
stt_model = get_model_instance_by_model_user_id(stt_model_id, self.flow_params_serializer.data.get('user_id'))
workspace_id = self.workflow_manage.get_body().get('workspace_id')
stt_model = get_model_instance_by_model_workspace_id(stt_model_id, workspace_id)
audio_list = audio
self.context['audio_list'] = audio

View File

@ -6,8 +6,8 @@ from django.core.files.uploadedfile import InMemoryUploadedFile
from application.flow.i_step_node import NodeResult
from application.flow.step_node.text_to_speech_step_node.i_text_to_speech_node import ITextToSpeechNode
from models_provider.tools import get_model_instance_by_model_workspace_id
from oss.serializers.file import FileSerializer
from models_provider.tools import get_model_instance_by_model_user_id
def bytes_to_uploaded_file(file_bytes, file_name="generated_audio.mp3"):
@ -42,8 +42,9 @@ class BaseTextToSpeechNode(ITextToSpeechNode):
content, model_params_setting=None,
**kwargs) -> NodeResult:
self.context['content'] = content
model = get_model_instance_by_model_user_id(tts_model_id, self.flow_params_serializer.data.get('user_id'),
**model_params_setting)
workspace_id = self.workflow_manage.get_body().get('workspace_id')
model = get_model_instance_by_model_workspace_id(tts_model_id, workspace_id,
**model_params_setting)
audio_byte = model.text_to_speech(content)
# 需要把这个音频文件存储到数据库中
file_name = 'generated_audio.mp3'

View File

@ -538,6 +538,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationOperateSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True, label=_("Application ID"))
user_id = serializers.UUIDField(required=True, label=_("User ID"))
workspace_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, label=_("Workspace ID"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -682,7 +683,6 @@ class ApplicationOperateSerializer(serializers.Serializer):
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
application.__setattr__(update_key, instance.get(update_key))
print(application.name)
application.save()
if 'knowledge_id_list' in instance:
@ -690,11 +690,11 @@ class ApplicationOperateSerializer(serializers.Serializer):
# 当前用户可修改关联的知识库列表
application_knowledge_id_list = [str(knowledge.id) for knowledge in
self.list_knowledge(with_valid=False)]
for dataset_id in knowledge_id_list:
if not application_knowledge_id_list.__contains__(dataset_id):
for knowledge_id in knowledge_id_list:
if not application_knowledge_id_list.__contains__(knowledge_id):
message = lazy_format(_('Unknown knowledge base id {dataset_id}, unable to associate'),
dataset_id=dataset_id)
raise AppApiException(500, message)
dataset_id=knowledge_id)
raise AppApiException(500, str(message))
self.save_application_knowledge_mapping(application_knowledge_id_list, knowledge_id_list, application_id)
return self.one(with_valid=False)
@ -707,8 +707,8 @@ class ApplicationOperateSerializer(serializers.Serializer):
knowledge_list = self.list_knowledge(with_valid=False)
mapping_knowledge_id_list = [akm.knowledge_id for akm in
QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)]
knowledge_id_list = [d.get('id') for d in
list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.get('id')),
knowledge_id_list = [d.id for d in
list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.id),
knowledge_list))]
return {**ApplicationSerializerModel(application).data,
'knowledge_id_list': knowledge_id_list}
@ -729,5 +729,5 @@ class ApplicationOperateSerializer(serializers.Serializer):
application_id=application_id).delete()
# 插入
QuerySet(ApplicationKnowledgeMapping).bulk_create(
[ApplicationKnowledgeMapping(application_id=application_id, dataset_id=dataset_id) for dataset_id in
[ApplicationKnowledgeMapping(application_id=application_id, knowledge_id=knowledge_id) for knowledge_id in
knowledge_id_list]) if len(knowledge_id_list) > 0 else None

View File

@ -98,7 +98,7 @@ class ChatInfo:
self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding',
'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting),
'user_id': self.application.user_id,
'workspace_id': self.application.workspace_id,
'application_id': self.application.id
}

View File

@ -130,6 +130,7 @@ class ApplicationAPI(APIView):
def post(self, request: Request, workspace_id: str, application_id: str):
return ApplicationOperateSerializer(
data={'application_id': application_id,
'workspace_id': workspace_id,
'user_id': request.user.id}).export(request.data)
class Operate(APIView):
@ -148,11 +149,12 @@ class ApplicationAPI(APIView):
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
@log(menu='Application', operate='Deleting application',
get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')),
)
def delete(self, request: Request, workspace_id: str, application_id: str):
return result.success(ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).delete(
data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).delete(
with_valid=True))
@extend_schema(
@ -173,7 +175,8 @@ class ApplicationAPI(APIView):
def put(self, request: Request, workspace_id: str, application_id: str):
return result.success(
ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).edit(
data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).edit(
request.data))
@extend_schema(
@ -190,7 +193,8 @@ class ApplicationAPI(APIView):
RoleConstants.WORKSPACE_MANAGE.get_workspace_role(), RoleConstants.ADMIN)
def get(self, request: Request, workspace_id: str, application_id: str):
return result.success(ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).one())
data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).one())
class Publish(APIView):
authentication_classes = [TokenAuth]
@ -207,9 +211,10 @@ class ApplicationAPI(APIView):
)
@log(menu='Application', operate='Publishing an application',
get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id')),
)
def put(self, request: Request, workspace_id: str, application_id: str):
return result.success(
ApplicationOperateSerializer(
data={'application_id': application_id, 'user_id': request.user.id}).publish(request.data))
data={'application_id': application_id, 'user_id': request.user.id,
'workspace_id': workspace_id, }).publish(request.data))

View File

@ -366,7 +366,7 @@ class OpenChatSerializers(serializers.Serializer):
chat_user_id = self.data.get("chat_user_id")
chat_user_type = self.data.get("chat_user_type")
debug = self.data.get("debug")
knowledge_id_list = [str(row.dataset_id) for row in
knowledge_id_list = [str(row.knowledge_id) for row in
QuerySet(ApplicationKnowledgeMapping).filter(
application_id=application_id)]
chat_id = str(uuid.uuid7())

View File

@ -103,21 +103,25 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict
raise_exception)
def get_model_by_id(_id, user_id):
def get_model_by_id(_id, workspace_id):
model = QuerySet(Model).filter(id=_id).first()
# 手动关闭数据库连接
connection.close()
if model is None:
raise Exception(_('Model does not exist'))
if model.workspace_id:
if model.workspace_id != workspace_id:
raise Exception(_('Model does not exist'))
return model
def get_model_instance_by_model_user_id(model_id, user_id, **kwargs):
def get_model_instance_by_model_workspace_id(model_id, workspace_id, **kwargs):
"""
获取模型实例,根据模型相关数据
@param model_id: 模型id
@param user_id: 用户id
@return: 模型实例
@param model_id: 模型id
@param workspace_id: 工作空间id
@return: 模型实例
"""
model = get_model_by_id(model_id, user_id)
model = get_model_by_id(model_id, workspace_id)
return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs))