feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-18 19:11:29 +08:00
parent 48d6812233
commit aeff8ab0b4
4 changed files with 16 additions and 6 deletions

View File

@ -111,6 +111,8 @@ class FlowParamsSerializer(serializers.Serializer):
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.CharField(required=True, error_messages=ErrMessage.char("用户id"))
class INode:
def __init__(self, node, workflow_params, workflow_manage):

View File

@ -23,11 +23,13 @@ from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id):
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
return get_model(model)
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list):
@ -53,7 +55,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in

View File

@ -221,11 +221,13 @@ class ChatMessageSerializer(serializers.Serializer):
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
r = work_flow_manage.run()
return r

View File

@ -260,6 +260,7 @@ class ChatSerializers(serializers.Serializer):
class OpenWorkFlowChat(serializers.Serializer):
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def open(self):
self.is_valid(raise_exception=True)
@ -270,7 +271,8 @@ class ChatSerializers(serializers.Serializer):
dataset_setting={},
model_setting={},
problem_optimization=None,
type=ApplicationTypeChoices.WORK_FLOW
type=ApplicationTypeChoices.WORK_FLOW,
user_id=self.data.get('user_id')
)
work_flow_version = WorkFlowVersion(work_flow=work_flow)
chat_cache.set(chat_id,
@ -333,7 +335,8 @@ class ChatSerializers(serializers.Serializer):
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'))
problem_optimization=self.data.get('problem_optimization'),
user_id=user_id)
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in