mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持向量模型
This commit is contained in:
parent
48d6812233
commit
aeff8ab0b4
|
|
@ -111,6 +111,8 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
user_id = serializers.CharField(required=True, error_messages=ErrMessage.char("用户id"))
|
||||
|
||||
|
||||
class INode:
|
||||
def __init__(self, node, workflow_params, workflow_manage):
|
||||
|
|
|
|||
|
|
@ -23,11 +23,13 @@ from setting.models_provider import get_model
|
|||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id):
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
return get_model(model)
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
raise Exception(f"无权限使用此模型:{model.name}")
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
|
|
@ -53,7 +55,8 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
if len(dataset_id_list) == 0:
|
||||
return get_none_result(question)
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
|
||||
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(question)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
|
|
|
|||
|
|
@ -221,11 +221,13 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
stream = self.data.get('stream')
|
||||
client_id = self.data.get('client_id')
|
||||
client_type = self.data.get('client_type')
|
||||
user_id = chat_info.application.user_id
|
||||
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
|
||||
'stream': stream,
|
||||
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||
're_chat': re_chat,
|
||||
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||
r = work_flow_manage.run()
|
||||
return r
|
||||
|
||||
|
|
|
|||
|
|
@ -260,6 +260,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
|
||||
class OpenWorkFlowChat(serializers.Serializer):
|
||||
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def open(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
@ -270,7 +271,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
dataset_setting={},
|
||||
model_setting={},
|
||||
problem_optimization=None,
|
||||
type=ApplicationTypeChoices.WORK_FLOW
|
||||
type=ApplicationTypeChoices.WORK_FLOW,
|
||||
user_id=self.data.get('user_id')
|
||||
)
|
||||
work_flow_version = WorkFlowVersion(work_flow=work_flow)
|
||||
chat_cache.set(chat_id,
|
||||
|
|
@ -333,7 +335,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
application = Application(id=None, dialogue_number=3, model=model,
|
||||
dataset_setting=self.data.get('dataset_setting'),
|
||||
model_setting=self.data.get('model_setting'),
|
||||
problem_optimization=self.data.get('problem_optimization'))
|
||||
problem_optimization=self.data.get('problem_optimization'),
|
||||
user_id=user_id)
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
|
|
|
|||
Loading…
Reference in New Issue