mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 模型管理支持向量模型,知识库可以关联向量模型
feat: 模型管理支持向量模型,知识库可以关联向量模型
This commit is contained in:
commit
d3d09b10ec
|
|
@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
|
||||
message="类型只支持register|reset_password", code=500)
|
||||
], error_messages=ErrMessage.char("检索模式"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
|
||||
return self.InstanceSerializer
|
||||
|
|
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
"""
|
||||
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
|
||||
|
|
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
|
|||
:param exclude_paragraph_id_list: 需要排除段落id
|
||||
:param padding_problem_text 补全问题
|
||||
:param search_mode 检索模式
|
||||
:param user_id 用户id
|
||||
:return: 段落列表
|
||||
"""
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,22 +13,48 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
|
||||
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Paragraph
|
||||
from dataset.models import Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
raise Exception(f"无权限使用此模型:{model.name}")
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception("知识库未向量模型不一致")
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception("知识库设置错误,请重新设置知识库")
|
||||
return dataset_list[0].embedding_mode_id
|
||||
|
||||
|
||||
class BaseSearchDatasetStep(ISearchDatasetStep):
|
||||
|
||||
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
|
||||
search_mode: str = None,
|
||||
user_id=None,
|
||||
**kwargs) -> List[ParagraphPipelineModel]:
|
||||
if len(dataset_id_list) == 0:
|
||||
return []
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
embedding_model = EmbeddingModel.get_embedding_model()
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, user_id)
|
||||
self.context['model_name'] = model.name
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(exec_problem_text)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
|
||||
|
|
@ -101,7 +127,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
|
|||
'run_time': self.context['run_time'],
|
||||
'problem_text': step_args.get(
|
||||
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
|
||||
'model_name': EmbeddingModel.get_embedding_model().model_name,
|
||||
'model_name': self.context.get('model_name'),
|
||||
'message_tokens': 0,
|
||||
'answer_tokens': 0,
|
||||
'cost': 0
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
|
|||
|
||||
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
|
||||
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -13,20 +13,50 @@ from django.db.models import QuerySet
|
|||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
|
||||
from common.config.embedding_config import EmbeddingModel, VectorStore
|
||||
from common.config.embedding_config import VectorStore, ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Document, Paragraph
|
||||
from dataset.models import Document, Paragraph, DataSet
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
def get_model_by_id(_id, user_id):
|
||||
model = QuerySet(Model).filter(id=_id).first()
|
||||
if model is None:
|
||||
raise Exception("模型不存在")
|
||||
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
|
||||
raise Exception(f"无权限使用此模型:{model.name}")
|
||||
return model
|
||||
|
||||
|
||||
def get_embedding_id(dataset_id_list):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception("知识库未向量模型不一致")
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception("知识库设置错误,请重新设置知识库")
|
||||
return dataset_list[0].embedding_mode_id
|
||||
|
||||
|
||||
def get_none_result(question):
|
||||
return NodeResult(
|
||||
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
|
||||
'directly_return': ''}, {})
|
||||
|
||||
|
||||
class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
**kwargs) -> NodeResult:
|
||||
self.context['question'] = question
|
||||
embedding_model = EmbeddingModel.get_embedding_model()
|
||||
if len(dataset_id_list) == 0:
|
||||
return get_none_result(question)
|
||||
model_id = get_embedding_id(dataset_id_list)
|
||||
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
|
||||
embedding_value = embedding_model.embed_query(question)
|
||||
vector = VectorStore.get_embedding_vector()
|
||||
exclude_document_id_list = [str(document.id) for document in
|
||||
|
|
@ -37,7 +67,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
|
|||
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
|
||||
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
|
||||
if embedding_list is None:
|
||||
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
|
||||
return get_none_result(question)
|
||||
paragraph_list = self.list_paragraph(embedding_list, vector)
|
||||
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
|
||||
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-15 15:52
|
||||
|
||||
import application.models.application
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0009_application_type_application_work_flow_and_more'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='details',
|
||||
field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'),
|
||||
),
|
||||
]
|
||||
|
|
@ -26,7 +26,7 @@ from rest_framework import serializers
|
|||
from application.flow.workflow_manage import Flow
|
||||
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
|
||||
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
|
|
@ -36,7 +36,7 @@ from common.util.common import valid_license
|
|||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import DataSet, Document, Image
|
||||
from dataset.serializers.common_serializers import list_paragraph
|
||||
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import AuthOperate
|
||||
from setting.models.model_management import Model
|
||||
|
|
@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
QuerySet(Document).filter(
|
||||
dataset_id__in=dataset_id_list,
|
||||
is_active=False)]
|
||||
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
|
||||
# 向量库检索
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
|
||||
self.data.get('top_number'),
|
||||
self.data.get('similarity'),
|
||||
SearchMode(self.data.get('search_mode')),
|
||||
EmbeddingModel.get_embedding_model())
|
||||
model)
|
||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
||||
|
|
@ -522,12 +523,14 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
|
||||
raise AppApiException(500, '不存在的应用id')
|
||||
|
||||
def list_model(self, with_valid=True):
|
||||
def list_model(self, model_type=None, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid()
|
||||
if model_type is None:
|
||||
model_type = "LLM"
|
||||
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
|
||||
return ModelSerializer.Query(
|
||||
data={'user_id': application.user_id}).list(
|
||||
data={'user_id': application.user_id, 'model_type': model_type}).list(
|
||||
with_valid=True)
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
|
|
|
|||
|
|
@ -88,7 +88,9 @@ class ChatInfo:
|
|||
'no_references_setting': self.application.dataset_setting.get(
|
||||
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
|
||||
'status': 'ai_questioning',
|
||||
'value': '{question}'}
|
||||
'value': '{question}',
|
||||
},
|
||||
'user_id': self.application.user_id
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -221,11 +223,13 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
stream = self.data.get('stream')
|
||||
client_id = self.data.get('client_id')
|
||||
client_type = self.data.get('client_type')
|
||||
user_id = chat_info.application.user_id
|
||||
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
|
||||
{'history_chat_record': chat_info.chat_record_list, 'question': message,
|
||||
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
|
||||
'stream': stream,
|
||||
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||
're_chat': re_chat,
|
||||
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
|
||||
r = work_flow_manage.run()
|
||||
return r
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
|
|||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||
ModelSettingSerializer
|
||||
from application.serializers.chat_message_serializers import ChatInfo
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.constants.permission_constants import RoleConstants
|
||||
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
|
||||
from common.event import ListenerManagement
|
||||
|
|
@ -39,8 +40,10 @@ from common.util.file_util import get_file_content
|
|||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
from setting.models_provider import get_model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -241,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
application_id=application_id)]
|
||||
chat_model = None
|
||||
if model is not None:
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
|
||||
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
|
||||
chat_id = str(uuid.uuid1())
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
|
|
@ -259,6 +257,7 @@ class ChatSerializers(serializers.Serializer):
|
|||
|
||||
class OpenWorkFlowChat(serializers.Serializer):
|
||||
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
|
||||
def open(self):
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
@ -269,7 +268,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
dataset_setting={},
|
||||
model_setting={},
|
||||
problem_optimization=None,
|
||||
type=ApplicationTypeChoices.WORK_FLOW
|
||||
type=ApplicationTypeChoices.WORK_FLOW,
|
||||
user_id=self.data.get('user_id')
|
||||
)
|
||||
work_flow_version = WorkFlowVersion(work_flow=work_flow)
|
||||
chat_cache.set(chat_id,
|
||||
|
|
@ -332,7 +332,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
application = Application(id=None, dialogue_number=3, model=model,
|
||||
dataset_setting=self.data.get('dataset_setting'),
|
||||
model_setting=self.data.get('model_setting'),
|
||||
problem_optimization=self.data.get('problem_optimization'))
|
||||
problem_optimization=self.data.get('problem_optimization'),
|
||||
user_id=user_id)
|
||||
chat_cache.set(chat_id,
|
||||
ChatInfo(chat_id, chat_model, dataset_id_list,
|
||||
[str(document.id) for document in
|
||||
|
|
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
raise AppApiException(500, "文档id不正确")
|
||||
|
||||
@staticmethod
|
||||
def post_embedding_paragraph(chat_record, paragraph_id):
|
||||
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
# 发送向量化事件
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
|
||||
return chat_record
|
||||
|
||||
@post(post_function=post_embedding_paragraph)
|
||||
|
|
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
||||
# 添加标注
|
||||
chat_record.save()
|
||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id
|
||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
||||
|
|
|
|||
|
|
@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
|
|||
}
|
||||
)
|
||||
|
||||
class Model(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
return [openapi.Parameter(name='application_id',
|
||||
in_=openapi.IN_PATH,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=True,
|
||||
description='应用id'),
|
||||
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
|
||||
type=openapi.TYPE_STRING,
|
||||
required=False,
|
||||
description='模型类型'),
|
||||
]
|
||||
|
||||
class ApiKey(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ class Application(APIView):
|
|||
@swagger_auto_schema(operation_summary="获取模型列表",
|
||||
operation_id="获取模型列表",
|
||||
tags=["应用"],
|
||||
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
|
||||
manual_parameters=ApplicationApi.Model.get_request_params_api())
|
||||
@has_permissions(ViewPermission(
|
||||
[RoleConstants.ADMIN, RoleConstants.USER],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
|
|
@ -185,7 +185,7 @@ class Application(APIView):
|
|||
return result.success(
|
||||
ApplicationSerializer.Operate(
|
||||
data={'application_id': application_id,
|
||||
'user_id': request.user.id}).list_model())
|
||||
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
|
||||
|
||||
class Profile(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
|
|
|||
|
|
@ -41,3 +41,7 @@ class MemCache(LocMemCache):
|
|||
delete_keys.append(key)
|
||||
for key in delete_keys:
|
||||
self._delete(key)
|
||||
|
||||
def clear_timeout_data(self):
|
||||
for key in self._cache.keys():
|
||||
self.get(key)
|
||||
|
|
|
|||
|
|
@ -6,33 +6,36 @@
|
|||
@date:2023/10/23 16:03
|
||||
@desc:
|
||||
"""
|
||||
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
|
||||
import time
|
||||
|
||||
from smartdoc.const import CONFIG
|
||||
from common.cache.mem_cache import MemCache
|
||||
|
||||
|
||||
class EmbeddingModel:
|
||||
instance = None
|
||||
class ModelManage:
|
||||
cache = MemCache('model', {})
|
||||
up_clear_time = time.time()
|
||||
|
||||
@staticmethod
|
||||
def get_embedding_model():
|
||||
"""
|
||||
获取向量化模型
|
||||
:return:
|
||||
"""
|
||||
if EmbeddingModel.instance is None:
|
||||
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
|
||||
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
|
||||
device = CONFIG.get('EMBEDDING_DEVICE')
|
||||
encode_kwargs = {'normalize_embeddings': True}
|
||||
e = HuggingFaceEmbeddings(
|
||||
model_name=model_name,
|
||||
cache_folder=cache_folder,
|
||||
model_kwargs={'device': device},
|
||||
encode_kwargs=encode_kwargs,
|
||||
)
|
||||
EmbeddingModel.instance = e
|
||||
return EmbeddingModel.instance
|
||||
def get_model(_id, get_model):
|
||||
model_instance = ModelManage.cache.get(_id)
|
||||
if model_instance is None:
|
||||
model_instance = get_model(_id)
|
||||
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
|
||||
return model_instance
|
||||
# 续期
|
||||
ModelManage.cache.touch(_id, timeout=60 * 30)
|
||||
ModelManage.clear_timeout_cache()
|
||||
return model_instance
|
||||
|
||||
@staticmethod
|
||||
def clear_timeout_cache():
|
||||
if time.time() - ModelManage.up_clear_time > 60:
|
||||
ModelManage.cache.clear_timeout_data()
|
||||
|
||||
@staticmethod
|
||||
def delete_key(_id):
|
||||
if ModelManage.cache.has_key(_id):
|
||||
ModelManage.cache.delete(_id)
|
||||
|
||||
|
||||
class VectorStore:
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
|
|||
|
||||
|
||||
def poxy(poxy_function):
|
||||
def inner(args):
|
||||
work_thread_pool.submit(poxy_function, args)
|
||||
def inner(args, **keywords):
|
||||
work_thread_pool.submit(poxy_function, args, **keywords)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def embedding_poxy(poxy_function):
|
||||
def inner(args):
|
||||
embedding_thread_pool.submit(poxy_function, args)
|
||||
def inner(args, **keywords):
|
||||
embedding_thread_pool.submit(poxy_function, args, **keywords)
|
||||
|
||||
return inner
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ from typing import List
|
|||
import django.db.models
|
||||
from blinker import signal
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import native_search, get_dynamics_model
|
||||
from common.event.common import poxy, embedding_poxy
|
||||
from common.util.file_util import get_file_content
|
||||
|
|
@ -46,22 +47,26 @@ class SyncWebDocumentArgs:
|
|||
|
||||
|
||||
class UpdateProblemArgs:
|
||||
def __init__(self, problem_id: str, problem_content: str):
|
||||
def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
|
||||
self.problem_id = problem_id
|
||||
self.problem_content = problem_content
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
|
||||
class UpdateEmbeddingDatasetIdArgs:
|
||||
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
|
||||
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings):
|
||||
self.paragraph_id_list = paragraph_id_list
|
||||
self.target_dataset_id = target_dataset_id
|
||||
self.target_embedding_model = target_embedding_model
|
||||
|
||||
|
||||
class UpdateEmbeddingDocumentIdArgs:
|
||||
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
|
||||
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
|
||||
target_embedding_model: Embeddings = None):
|
||||
self.paragraph_id_list = paragraph_id_list
|
||||
self.target_document_id = target_document_id
|
||||
self.target_dataset_id = target_dataset_id
|
||||
self.target_embedding_model = target_embedding_model
|
||||
|
||||
|
||||
class ListenerManagement:
|
||||
|
|
@ -84,16 +89,46 @@ class ListenerManagement:
|
|||
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_problem(args):
|
||||
VectorStore.get_embedding_vector().save(**args)
|
||||
def embedding_by_problem(args, embedding_model: Embeddings):
|
||||
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
|
||||
try:
|
||||
data_list = native_search(
|
||||
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
|
||||
**{'paragraph.id__in': paragraph_id_list}),
|
||||
'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
|
||||
ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
|
||||
embedding_model=embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
|
||||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
def embedding_by_paragraph(paragraph_id):
|
||||
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
|
||||
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
|
||||
try:
|
||||
# 删除段落
|
||||
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
|
||||
# 批量向量化
|
||||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
|
||||
status = Status.error
|
||||
finally:
|
||||
QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
|
||||
max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
|
||||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
|
||||
"""
|
||||
向量化段落 根据段落id
|
||||
:param paragraph_id: 段落id
|
||||
:return: None
|
||||
@param paragraph_id: 段落id
|
||||
@param embedding_model: 向量模型
|
||||
"""
|
||||
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
|
||||
status = Status.success
|
||||
|
|
@ -107,7 +142,7 @@ class ListenerManagement:
|
|||
# 删除段落
|
||||
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
|
||||
# 批量向量化
|
||||
VectorStore.get_embedding_vector().batch_save(data_list)
|
||||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||
status = Status.error
|
||||
|
|
@ -117,12 +152,15 @@ class ListenerManagement:
|
|||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
def embedding_by_document(document_id):
|
||||
def embedding_by_document(document_id, embedding_model: Embeddings):
|
||||
"""
|
||||
向量化文档
|
||||
:param document_id: 文档id
|
||||
@param document_id: 文档id
|
||||
@param embedding_model 向量模型
|
||||
:return: None
|
||||
"""
|
||||
if not try_lock('embedding' + str(document_id)):
|
||||
return
|
||||
max_kb.info(f"开始--->向量化文档:{document_id}")
|
||||
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
|
||||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
|
||||
|
|
@ -138,7 +176,7 @@ class ListenerManagement:
|
|||
# 删除文档向量数据
|
||||
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
|
||||
# 批量向量化
|
||||
VectorStore.get_embedding_vector().batch_save(data_list)
|
||||
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||
status = Status.error
|
||||
|
|
@ -148,21 +186,24 @@ class ListenerManagement:
|
|||
**{'status': status, 'update_time': datetime.datetime.now()})
|
||||
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
|
||||
max_kb.info(f"结束--->向量化文档:{document_id}")
|
||||
un_lock('embedding' + str(document_id))
|
||||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
def embedding_by_dataset(dataset_id):
|
||||
def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
|
||||
"""
|
||||
向量化知识库
|
||||
:param dataset_id: 知识库id
|
||||
@param dataset_id: 知识库id
|
||||
@param embedding_model 向量模型
|
||||
:return: None
|
||||
"""
|
||||
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
|
||||
try:
|
||||
ListenerManagement.delete_embedding_by_dataset(dataset_id)
|
||||
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
||||
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
||||
for document in document_list:
|
||||
ListenerManagement.embedding_by_document(document.id)
|
||||
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||
finally:
|
||||
|
|
@ -224,14 +265,22 @@ class ListenerManagement:
|
|||
|
||||
@staticmethod
|
||||
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
|
||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||
{'dataset_id': args.target_dataset_id})
|
||||
if args.target_embedding_model is None:
|
||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||
{'dataset_id': args.target_dataset_id})
|
||||
else:
|
||||
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
|
||||
embedding_model=args.target_embedding_model)
|
||||
|
||||
@staticmethod
|
||||
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
|
||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||
{'document_id': args.target_document_id,
|
||||
'dataset_id': args.target_dataset_id})
|
||||
if args.target_embedding_model is None:
|
||||
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
|
||||
{'document_id': args.target_document_id,
|
||||
'dataset_id': args.target_dataset_id})
|
||||
else:
|
||||
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
|
||||
embedding_model=args.target_embedding_model)
|
||||
|
||||
@staticmethod
|
||||
def delete_embedding_by_source_ids(source_ids: List[str]):
|
||||
|
|
@ -245,11 +294,6 @@ class ListenerManagement:
|
|||
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
|
||||
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
|
||||
|
||||
@staticmethod
|
||||
@poxy
|
||||
def init_embedding_model(ages):
|
||||
EmbeddingModel.get_embedding_model()
|
||||
|
||||
def run(self):
|
||||
# 添加向量 根据问题id
|
||||
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
|
||||
|
|
@ -276,8 +320,7 @@ class ListenerManagement:
|
|||
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
|
||||
# 启动段落向量
|
||||
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
|
||||
# 初始化向量化模型
|
||||
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
|
||||
|
||||
# 同步web站点知识库
|
||||
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
|
||||
# 同步web站点 文档
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-17 13:56
|
||||
|
||||
import dataset.models.data_set
|
||||
from django.db import migrations, models
|
||||
import django.db.models.deletion
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0005_model_permission_type'),
|
||||
('dataset', '0005_file'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='dataset',
|
||||
name='embedding_mode',
|
||||
field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'),
|
||||
),
|
||||
]
|
||||
|
|
@ -9,9 +9,11 @@
|
|||
import uuid
|
||||
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.db.sql_execute import select_one
|
||||
from common.mixins.app_model_mixin import AppModelMixin
|
||||
from setting.models import Model
|
||||
from users.models import User
|
||||
|
||||
|
||||
|
|
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
|
|||
directly_return = 'directly_return', '直接返回'
|
||||
|
||||
|
||||
def default_model():
|
||||
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
|
||||
|
||||
|
||||
class DataSet(AppModelMixin):
|
||||
"""
|
||||
数据集表
|
||||
|
|
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
|
|||
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
|
||||
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
|
||||
default=Type.base)
|
||||
|
||||
embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
|
||||
default=default_model)
|
||||
meta = models.JSONField(verbose_name="元数据", default=dict)
|
||||
|
||||
class Meta:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from django.db.models import QuerySet
|
|||
from drf_yasg import openapi
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.config.embedding_config import ModelManage
|
||||
from common.db.search import native_search
|
||||
from common.db.sql_execute import update_execute
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
|
|||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.fork import Fork
|
||||
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
|
||||
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
|
||||
from setting.models_provider import get_model
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -130,3 +132,22 @@ class ProblemParagraphManage:
|
|||
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
|
||||
is_create], problem_paragraph_mapping_list
|
||||
return result
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
||||
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
|
||||
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
|
||||
raise Exception("知识库未向量模型不一致")
|
||||
if len(dataset_list) == 0:
|
||||
raise Exception("知识库设置错误,请重新设置知识库")
|
||||
return ModelManage.get_model(str(dataset_list[0].id),
|
||||
lambda _id: get_model(dataset_list[0].embedding_mode))
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset_id(dataset_id: str):
|
||||
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
|
||||
return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
|
||||
|
||||
|
||||
def get_embedding_model_by_dataset(dataset):
|
||||
return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from functools import reduce
|
|||
from typing import Dict, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.postgres.fields import ArrayField
|
||||
from django.core import validators
|
||||
from django.db import transaction, models
|
||||
|
|
@ -25,7 +24,7 @@ from drf_yasg import openapi
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.models import ApplicationDatasetMapping
|
||||
from common.config.embedding_config import VectorStore, EmbeddingModel
|
||||
from common.config.embedding_config import VectorStore
|
||||
from common.db.search import get_dynamics_model, native_page_search, native_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.event import ListenerManagement, SyncWebDatasetArgs
|
||||
|
|
@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
|
|||
from common.util.fork import ChildLink, Fork
|
||||
from common.util.split_model import get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
|
||||
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
|
||||
get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
|
||||
from embedding.models import SearchMode
|
||||
from setting.models import AuthOperate
|
||||
|
|
@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
max_length=256,
|
||||
min_length=1)
|
||||
|
||||
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||
|
||||
documents = DocumentInstanceSerializer(required=False, many=True)
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
|
|
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
max_length=256,
|
||||
min_length=1)
|
||||
|
||||
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||
|
||||
file_list = serializers.ListSerializer(required=True,
|
||||
error_messages=ErrMessage.list("文件列表"),
|
||||
child=serializers.FileField(required=True,
|
||||
|
|
@ -296,6 +300,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
min_length=1)
|
||||
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
|
||||
|
||||
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
|
||||
|
||||
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("选择器"))
|
||||
|
||||
|
|
@ -347,6 +353,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id",
|
||||
description="向量模型id"),
|
||||
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
|
||||
description="web站点url"),
|
||||
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
|
||||
|
|
@ -355,8 +363,9 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
|
||||
@staticmethod
|
||||
def post_embedding_dataset(document_list, dataset_id):
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
# 发送向量化事件
|
||||
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
|
||||
ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
|
||||
return document_list
|
||||
|
||||
def save_qa(self, instance: Dict, with_valid=True):
|
||||
|
|
@ -365,7 +374,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
self.CreateQASerializers(data=instance).is_valid()
|
||||
file_list = instance.get('file_list')
|
||||
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
|
||||
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
|
||||
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
|
||||
'embedding_mode_id': instance.get('embedding_mode_id')}
|
||||
return self.save(dataset_instance, with_valid=True)
|
||||
|
||||
@valid_license(model=DataSet, count=50,
|
||||
|
|
@ -381,7 +391,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
|
||||
raise AppApiException(500, "知识库名称重复!")
|
||||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||
'embedding_mode_id': instance.get('embedding_mode_id')})
|
||||
|
||||
document_model_list = []
|
||||
paragraph_model_list = []
|
||||
|
|
@ -452,7 +463,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
dataset = DataSet(
|
||||
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
|
||||
'type': Type.web,
|
||||
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
|
||||
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
|
||||
'embedding_mode_id': instance.get('embedding_mode_id')}})
|
||||
dataset.save()
|
||||
ListenerManagement.sync_web_dataset_signal.send(
|
||||
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
|
||||
|
|
@ -500,6 +512,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
properties={
|
||||
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
|
||||
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
|
||||
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
|
||||
description='向量模型'),
|
||||
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
|
||||
items=DocumentSerializers().Create.get_request_body_api()
|
||||
)
|
||||
|
|
@ -557,12 +571,13 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
QuerySet(Document).filter(
|
||||
dataset_id=self.data.get('id'),
|
||||
is_active=False)]
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('id'))
|
||||
# 向量库检索
|
||||
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
|
||||
self.data.get('top_number'),
|
||||
self.data.get('similarity'),
|
||||
SearchMode(self.data.get('search_mode')),
|
||||
EmbeddingModel.get_embedding_model())
|
||||
model)
|
||||
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
|
||||
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
|
||||
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
|
||||
|
|
@ -730,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
def re_embedding(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('id'))
|
||||
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
|
||||
|
||||
def list_application(self, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -769,6 +785,7 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
QuerySet(ApplicationDatasetMapping).filter(
|
||||
dataset_id=self.data.get('id'))]))}
|
||||
|
||||
@transaction.atomic
|
||||
def edit(self, dataset: Dict, user_id: str):
|
||||
"""
|
||||
修改知识库
|
||||
|
|
@ -782,6 +799,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
raise AppApiException(500, "知识库名称重复!")
|
||||
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
|
||||
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
|
||||
if 'embedding_mode_id' in dataset:
|
||||
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
|
||||
if "name" in dataset:
|
||||
_dataset.name = dataset.get("name")
|
||||
if 'desc' in dataset:
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
|
|||
from common.util.fork import Fork
|
||||
from common.util.split_model import get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
|
||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
|
||||
get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -234,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
meta={})
|
||||
else:
|
||||
document_list.update(dataset_id=target_dataset_id)
|
||||
# 修改向量信息
|
||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_dataset_id))
|
||||
model = None
|
||||
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
|
||||
model = get_embedding_model_by_dataset_id(target_dataset_id)
|
||||
|
||||
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||
# 修改段落信息
|
||||
paragraph_list.update(dataset_id=target_dataset_id)
|
||||
# 修改向量信息
|
||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||
pid_list,
|
||||
target_dataset_id, model))
|
||||
|
||||
@staticmethod
|
||||
def get_target_dataset_problem(target_dataset_id: str,
|
||||
|
|
@ -392,7 +398,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
problem_paragraph_mapping_list) > 0 else None
|
||||
# 向量化
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
else:
|
||||
document.status = Status.error
|
||||
document.save()
|
||||
|
|
@ -405,6 +412,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
class Operate(ApiMixin, serializers.Serializer):
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||
"文档id"))
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
|
|
@ -530,7 +538,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
document_id = self.data.get("document_id")
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
|
||||
@transaction.atomic
|
||||
def delete(self):
|
||||
|
|
@ -599,8 +608,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(result, document_id):
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
def post_embedding(result, document_id, dataset_id):
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -646,7 +656,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
document_id = str(document_model.id)
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True), document_id
|
||||
with_valid=True), document_id, dataset_id
|
||||
|
||||
@staticmethod
|
||||
def get_sync_handler(dataset_id):
|
||||
|
|
@ -803,9 +813,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(document_list):
|
||||
def post_embedding(document_list, dataset_id):
|
||||
for document_dict in document_list:
|
||||
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
|
||||
return document_list
|
||||
|
||||
@post(post_function=post_embedding)
|
||||
|
|
@ -846,7 +857,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return [],
|
||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
|
||||
with_search_one=False), dataset_id
|
||||
|
||||
@staticmethod
|
||||
def _batch_sync(document_id_list: List[str]):
|
||||
|
|
|
|||
|
|
@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
|
|||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.common import post
|
||||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||
ProblemParagraphManage
|
||||
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
|
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
paragraph_id=self.data.get('paragraph_id'),
|
||||
dataset_id=self.data.get('dataset_id'))
|
||||
problem_paragraph_mapping.save()
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
|
|
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
}, embedding_model=model)
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'),
|
||||
|
|
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
problem_id=problem.id)
|
||||
problem_paragraph_mapping.save()
|
||||
if with_embedding:
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
|
|
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
}, embedding_model=model)
|
||||
|
||||
def un_association(self, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -336,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改mapping
|
||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||
['document_id'])
|
||||
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_document_id, target_dataset_id))
|
||||
target_document_id, target_dataset_id, target_embedding_model=None))
|
||||
# 修改段落信息
|
||||
paragraph_list.update(document_id=target_document_id)
|
||||
# 不同数据集迁移
|
||||
|
|
@ -366,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改mapping
|
||||
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
|
||||
['problem_id', 'dataset_id', 'document_id'])
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
[paragraph.id for paragraph in paragraph_list],
|
||||
target_document_id, target_dataset_id))
|
||||
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
|
||||
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
|
||||
embedding_model = None
|
||||
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
|
||||
embedding_model = get_embedding_model_by_dataset(target_dataset)
|
||||
pid_list = [paragraph.id for paragraph in paragraph_list]
|
||||
# 修改段落信息
|
||||
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
|
||||
# 修改向量段落信息
|
||||
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
|
||||
pid_list,
|
||||
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
|
||||
|
||||
update_document_char_length(document_id)
|
||||
update_document_char_length(target_document_id)
|
||||
|
||||
|
|
@ -454,13 +464,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
raise AppApiException(500, "段落id不存在")
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(paragraph, instance):
|
||||
def post_embedding(paragraph, instance, dataset_id):
|
||||
if 'is_active' in instance and instance.get('is_active') is not None:
|
||||
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
||||
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
||||
s.send(paragraph.get('id'))
|
||||
else:
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
|
||||
return paragraph
|
||||
|
||||
@post(post_embedding)
|
||||
|
|
@ -508,7 +519,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
_paragraph.save()
|
||||
update_document_char_length(self.data.get('document_id'))
|
||||
return self.one(), instance
|
||||
return self.one(), instance, self.data.get('dataset_id')
|
||||
|
||||
def get_problem_list(self):
|
||||
ProblemParagraphMapping(ProblemParagraphMapping)
|
||||
|
|
@ -582,7 +593,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改长度
|
||||
update_document_char_length(document_id)
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
|
||||
return ParagraphSerializers.Operate(
|
||||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True)
|
||||
|
|
|
|||
|
|
@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
|
|||
from common.mixins.api_mixin import ApiMixin
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
|
||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
||||
|
|
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
|
|||
content = instance.get('content')
|
||||
problem = QuerySet(Problem).filter(id=problem_id,
|
||||
dataset_id=dataset_id).first()
|
||||
QuerySet(DataSet).filter(id=dataset_id)
|
||||
problem.content = content
|
||||
problem.save()
|
||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))
|
||||
|
|
|
|||
|
|
@ -52,7 +52,6 @@ class Dataset(APIView):
|
|||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="创建QA知识库",
|
||||
operation_id="创建QA知识库",
|
||||
|
||||
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
|
||||
responses=get_api_response(
|
||||
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),
|
||||
|
|
|
|||
|
|
@ -10,9 +10,8 @@ import threading
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.config.embedding_config import EmbeddingModel
|
||||
from common.util.common import sub_array
|
||||
from embedding.models import SourceType, SearchMode
|
||||
|
||||
|
|
@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
|
|||
|
||||
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
embedding=None):
|
||||
embedding: Embeddings):
|
||||
"""
|
||||
插入向量数据
|
||||
:param source_id: 资源id
|
||||
|
|
@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
|
|||
:param paragraph_id 段落id
|
||||
:return: bool
|
||||
"""
|
||||
|
||||
if embedding is None:
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
self.save_pre_handler()
|
||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
||||
|
||||
def batch_save(self, data_list: List[Dict], embedding=None):
|
||||
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
|
||||
# 获取锁
|
||||
lock.acquire()
|
||||
try:
|
||||
|
|
@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
|
|||
:param embedding: 向量化处理器
|
||||
:return: bool
|
||||
"""
|
||||
if embedding is None:
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
self.save_pre_handler()
|
||||
result = sub_array(data_list)
|
||||
for child_array in result:
|
||||
|
|
@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
|
|||
@abstractmethod
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
embedding: Embeddings):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
||||
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
|
||||
pass
|
||||
|
||||
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
|
||||
exclude_paragraph_list: list[str],
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
embedding: Embeddings):
|
||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||
return []
|
||||
embedding_query = embedding.embed_query(query_text)
|
||||
|
|
@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
|
|||
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||
similarity: float,
|
||||
search_mode: SearchMode,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
embedding: Embeddings):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
|
@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
|
|||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_documents(self, text_list: List[str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embed_query(self, text: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete_by_dataset_id(self, dataset_id: str):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
|
|||
from typing import Dict, List
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.config.embedding_config import EmbeddingModel
|
||||
from common.db.search import generate_sql_by_query_dict
|
||||
from common.db.sql_execute import select_list
|
||||
from common.util.file_util import get_file_content
|
||||
|
|
@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
|
|||
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
|
||||
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
|
||||
|
||||
def embed_documents(self, text_list: List[str]):
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
return embedding.embed_documents(text_list)
|
||||
|
||||
def embed_query(self, text: str):
|
||||
embedding = EmbeddingModel.get_embedding_model()
|
||||
return embedding.embed_query(text)
|
||||
|
||||
def vector_is_create(self) -> bool:
|
||||
# 项目启动默认是创建好的 不需要再创建
|
||||
return True
|
||||
|
|
@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):
|
|||
|
||||
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
|
||||
is_active: bool,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
embedding: Embeddings):
|
||||
text_embedding = embedding.embed_query(text)
|
||||
embedding = Embedding(id=uuid.uuid1(),
|
||||
dataset_id=dataset_id,
|
||||
|
|
@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
|
|||
embedding.save()
|
||||
return True
|
||||
|
||||
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
|
||||
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
|
||||
texts = [row.get('text') for row in text_list]
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
embedding_list = [Embedding(id=uuid.uuid1(),
|
||||
|
|
@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
|
|||
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
|
||||
similarity: float,
|
||||
search_mode: SearchMode,
|
||||
embedding: HuggingFaceEmbeddings):
|
||||
embedding: Embeddings):
|
||||
if dataset_id_list is None or len(dataset_id_list) == 0:
|
||||
return []
|
||||
exclude_dict = {}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# Generated by Django 4.2.13 on 2024-07-15 15:23
|
||||
import json
|
||||
|
||||
from django.db import migrations, models
|
||||
from django.db.models import QuerySet
|
||||
|
||||
from common.util.rsa_util import rsa_long_encrypt
|
||||
from setting.models import Status, PermissionType
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
|
||||
|
||||
|
||||
def save_default_embedding_model(apps, schema_editor):
|
||||
ModelModel = apps.get_model('setting', 'Model')
|
||||
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
|
||||
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
|
||||
credential = {'cache_folder': cache_folder}
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS,
|
||||
model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
|
||||
provider='model_local_provider',
|
||||
credential=rsa_long_encrypt(model_credential_str), meta={},
|
||||
permission_type=PermissionType.PUBLIC)
|
||||
model.save()
|
||||
|
||||
|
||||
def reverse_code_embedding_model(apps, schema_editor):
|
||||
ModelModel = apps.get_model('setting', 'Model')
|
||||
QuerySet(ModelModel).filter(id=default_embedding_model_id).delete()
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
('setting', '0004_alter_model_credential'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='model',
|
||||
name='permission_type',
|
||||
field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20,
|
||||
verbose_name='权限类型'),
|
||||
),
|
||||
migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model)
|
||||
]
|
||||
|
|
@ -22,6 +22,13 @@ class Status(models.TextChoices):
|
|||
|
||||
DOWNLOAD = "DOWNLOAD", '下载中'
|
||||
|
||||
PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
|
||||
|
||||
|
||||
class PermissionType(models.TextChoices):
|
||||
PUBLIC = "PUBLIC", '公开'
|
||||
PRIVATE = "PRIVATE", "私有"
|
||||
|
||||
|
||||
class Model(AppModelMixin):
|
||||
"""
|
||||
|
|
@ -46,6 +53,9 @@ class Model(AppModelMixin):
|
|||
|
||||
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||
|
||||
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
|
||||
default=PermissionType.PRIVATE)
|
||||
|
||||
class Meta:
|
||||
db_table = "model"
|
||||
unique_together = ['name', 'user_id']
|
||||
|
|
|
|||
|
|
@ -6,3 +6,85 @@
|
|||
@date:2023/10/31 17:16
|
||||
@desc:
|
||||
"""
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
||||
|
||||
def get_model_(provider, model_type, model_name, credential):
|
||||
"""
|
||||
获取模型实例
|
||||
@param provider: 供应商
|
||||
@param model_type: 模型类型
|
||||
@param model_name: 模型名称
|
||||
@param credential: 认证信息
|
||||
@return: 模型实例
|
||||
"""
|
||||
model = get_provider(provider).get_model(model_type, model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(credential)),
|
||||
streaming=True)
|
||||
return model
|
||||
|
||||
|
||||
def get_model(model):
|
||||
"""
|
||||
获取模型实例
|
||||
@param model: model 数据库Model实例对象
|
||||
@return: 模型实例
|
||||
"""
|
||||
return get_model_(model.provider, model.model_type, model.model_name, model.credential)
|
||||
|
||||
|
||||
def get_provider(provider):
|
||||
"""
|
||||
获取供应商实例
|
||||
@param provider: 供应商字符串
|
||||
@return: 供应商实例
|
||||
"""
|
||||
return ModelProvideConstants[provider].value
|
||||
|
||||
|
||||
def get_model_list(provider, model_type):
|
||||
"""
|
||||
获取模型列表
|
||||
@param provider: 供应商字符串
|
||||
@param model_type: 模型类型
|
||||
@return: 模型列表
|
||||
"""
|
||||
return get_provider(provider).get_model_list(model_type)
|
||||
|
||||
|
||||
def get_model_credential(provider, model_type, model_name):
|
||||
"""
|
||||
获取模型认证实例
|
||||
@param provider: 供应商字符串
|
||||
@param model_type: 模型类型
|
||||
@param model_name: 模型名称
|
||||
@return: 认证实例对象
|
||||
"""
|
||||
return get_provider(provider).get_model_credential(model_type, model_name)
|
||||
|
||||
|
||||
def get_model_type_list(provider):
|
||||
"""
|
||||
获取模型类型列表
|
||||
@param provider: 供应商字符串
|
||||
@return: 模型类型列表
|
||||
"""
|
||||
return get_provider(provider).get_model_type_list()
|
||||
|
||||
|
||||
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
|
||||
"""
|
||||
校验模型认证参数
|
||||
@param provider: 供应商字符串
|
||||
@param model_type: 模型类型
|
||||
@param model_name: 模型名称
|
||||
@param model_credential: 模型认证数据
|
||||
@param raise_exception: 是否抛出错误
|
||||
@return: True|False
|
||||
"""
|
||||
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ class IModelProvider(ABC):
|
|||
def get_model_list(self, model_type):
|
||||
if model_type is None:
|
||||
raise AppApiException(500, '模型类型不能为空')
|
||||
return self.get_model_info_manage().get_model_list()
|
||||
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
|
||||
|
||||
def get_model_credential(self, model_type, model_name):
|
||||
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
|
||||
|
|
@ -191,6 +191,9 @@ class ModelInfoManage:
|
|||
def get_model_list(self):
|
||||
return [model.to_dict() for model in self.model_list]
|
||||
|
||||
def get_model_list_by_model_type(self, model_type):
|
||||
return [model.to_dict() for model in self.model_list if model.model_type == model_type]
|
||||
|
||||
def get_model_type_list(self):
|
||||
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
|
||||
len([model for model in self.model_list if model.model_type == _type.name]) > 0]
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
|
|||
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
|
||||
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
|
||||
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
|
||||
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
|
||||
|
||||
|
||||
class ModelProvideConstants(Enum):
|
||||
|
|
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
|
|||
model_xf_provider = XunFeiModelProvider()
|
||||
model_deepseek_provider = DeepSeekModelProvider()
|
||||
model_gemini_provider = GeminiModelProvider()
|
||||
model_local_provider = LocalModelProvider()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/7/10 17:48
|
||||
@desc:
|
||||
"""
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 11:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
|
||||
|
||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
if not model_type == 'EMBEDDING':
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
for key in ['cache_folder']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query('你好')
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return model
|
||||
|
||||
cache_folder = forms.TextInputField('模型目录', required=True)
|
||||
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>
|
||||
|
After Width: | Height: | Size: 931 B |
|
|
@ -0,0 +1,38 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: zhipu_model_provider.py
|
||||
@date:2024/04/19 13:5
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
|
||||
LocalEmbeddingCredential(), LocalEmbedding)
|
||||
|
||||
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
|
||||
.append_default_model_info(embedding_text2vec_base_chinese)
|
||||
.build())
|
||||
|
||||
|
||||
class LocalModelProvider(IModelProvider):
|
||||
|
||||
def get_model_info_manage(self):
|
||||
return model_info_manage
|
||||
|
||||
def get_model_provide_info(self):
|
||||
return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon',
|
||||
'local_icon_svg')))
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 14:06
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_huggingface import HuggingFaceEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
|
||||
model_kwargs={'device': model_credential.get('device')},
|
||||
encode_kwargs={'normalize_embeddings': True},
|
||||
)
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 15:10
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
|
||||
|
||||
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
try:
|
||||
model_list = provider.get_base_model_list(model_credential.get('api_base'))
|
||||
except Exception as e:
|
||||
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
|
||||
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
|
||||
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
|
||||
if len(exist) == 0:
|
||||
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
|
||||
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query('你好')
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model_info: Dict[str, object]):
|
||||
return model_info
|
||||
|
||||
def build_model(self, model_info: Dict[str, object]):
|
||||
for key in ['model']:
|
||||
if key not in model_info:
|
||||
raise AppApiException(500, f'{key} 字段为必填字段')
|
||||
return self
|
||||
|
||||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 15:02
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain_community.embeddings import OllamaEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return OllamaEmbedding(
|
||||
model=model_name,
|
||||
base_url=model_credential.get('api_base'),
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed documents using an Ollama deployed embedding model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
instruction_pairs = [f"{text}" for text in texts]
|
||||
embeddings = self._embed(instruction_pairs)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a query using a Ollama deployed embedding model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
instruction_pair = f"{text}"
|
||||
embedding = self._embed([instruction_pair])[0]
|
||||
return embedding
|
||||
|
|
@ -20,7 +20,9 @@ from common.forms import BaseForm
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
|
||||
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
|
||||
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
|
||||
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -88,14 +90,25 @@ model_info_list = [
|
|||
ModelInfo(
|
||||
'phi3',
|
||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
|
||||
]
|
||||
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
|
||||
embedding_model_info = [
|
||||
ModelInfo(
|
||||
'nomic-embed-text',
|
||||
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
|
||||
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
|
||||
]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list(
|
||||
embedding_model_info).append_default_model_info(
|
||||
ModelInfo(
|
||||
'phi3',
|
||||
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
|
||||
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo(
|
||||
'nomic-embed-text',
|
||||
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
|
||||
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build()
|
||||
|
||||
|
||||
def get_base_url(url: str):
|
||||
|
|
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
|
|||
temp = ""
|
||||
|
||||
if len(temp) > 0:
|
||||
print(temp)
|
||||
rows = [t for t in temp.split("\n") if len(t) > 0]
|
||||
for row in rows:
|
||||
yield convert_to_down_model_chunk(row, index)
|
||||
|
|
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
|
|||
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
|
||||
'ollama_icon_svg')))
|
||||
|
||||
def get_dialogue_number(self):
|
||||
return 2
|
||||
|
||||
@staticmethod
|
||||
def get_base_model_list(api_base):
|
||||
base_url = get_base_url(api_base)
|
||||
|
|
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
|
|||
return r.json()
|
||||
|
||||
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
|
||||
api_base = model_credential.get('api_base')
|
||||
api_base = model_credential.get('api_base', '')
|
||||
base_url = get_base_url(api_base)
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
|
||||
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
|
||||
raise_exception=True):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
|
||||
|
||||
for key in ['api_base', 'api_key']:
|
||||
if key not in model_credential:
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
|
||||
else:
|
||||
return False
|
||||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query('你好')
|
||||
except Exception as e:
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
|
||||
|
||||
api_base = forms.TextInputField('API 域名', required=True)
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/12 17:44
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
||||
|
||||
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
|
||||
@staticmethod
|
||||
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
|
||||
return OpenAIEmbeddingModel(
|
||||
api_key=model_credential.get('api_key'),
|
||||
model=model_name,
|
||||
openai_api_base=model_credential.get('api_base'),
|
||||
)
|
||||
|
|
@ -11,7 +11,9 @@ import os
|
|||
from common.util.file_util import get_file_content
|
||||
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
|
||||
ModelTypeConst, ModelInfoManage
|
||||
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
|
||||
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
|
||||
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
|
||||
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -58,11 +60,17 @@ model_info_list = [
|
|||
ModelTypeConst.LLM, openai_llm_model_credential,
|
||||
OpenAIChatModel)
|
||||
]
|
||||
open_ai_embedding_credential = OpenAIEmbeddingCredential()
|
||||
model_info_embedding_list = [
|
||||
ModelInfo('text-embedding-ada-002', '',
|
||||
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
|
||||
OpenAIEmbeddingModel)]
|
||||
|
||||
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
|
||||
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
|
||||
openai_llm_model_credential, OpenAIChatModel
|
||||
)).build()
|
||||
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
|
||||
model_info_embedding_list[0]).build()
|
||||
|
||||
|
||||
class OpenAIModelProvider(IModelProvider):
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@
|
|||
@desc:
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.core import validators
|
||||
from django.db.models import QuerySet, Q
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.models import Application
|
||||
|
|
@ -36,6 +38,9 @@ class ModelPullManage:
|
|||
for chunk in response:
|
||||
down_model_chunk[chunk.digest] = chunk.to_dict()
|
||||
if time.time() - timestamp > 5:
|
||||
model_new = QuerySet(Model).filter(id=model.id).first()
|
||||
if model_new.status == Status.PAUSE_DOWNLOAD:
|
||||
return
|
||||
QuerySet(Model).filter(id=model.id).update(
|
||||
meta={"down_model_chunk": list(down_model_chunk.values())})
|
||||
timestamp = time.time()
|
||||
|
|
@ -72,7 +77,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
self.is_valid(raise_exception=True)
|
||||
user_id = self.data.get('user_id')
|
||||
name = self.data.get('name')
|
||||
model_query_set = QuerySet(Model).filter(user_id=user_id)
|
||||
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
|
||||
query_params = {}
|
||||
if name is not None:
|
||||
query_params['name__contains'] = name
|
||||
|
|
@ -85,7 +90,8 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
return [
|
||||
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
|
||||
'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
|
||||
'permission_type': model.permission_type} for model in
|
||||
model_query_set.filter(**query_params).order_by("-create_time")]
|
||||
|
||||
class Edit(serializers.Serializer):
|
||||
|
|
@ -96,6 +102,11 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
||||
|
||||
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
|
||||
message="权限只支持PUBLIC|PRIVATE", code=500)
|
||||
])
|
||||
|
||||
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
|
||||
|
||||
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
|
||||
|
|
@ -135,6 +146,11 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
|
||||
|
||||
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
|
||||
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
|
||||
message="权限只支持PUBLIC|PRIVATE", code=500)
|
||||
])
|
||||
|
||||
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
|
||||
|
||||
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
|
||||
|
|
@ -165,10 +181,12 @@ class ModelSerializer(serializers.Serializer):
|
|||
provider = self.data.get('provider')
|
||||
model_type = self.data.get('model_type')
|
||||
model_name = self.data.get('model_name')
|
||||
permission_type = self.data.get('permission_type')
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||
credential=rsa_long_encrypt(model_credential_str),
|
||||
provider=provider, model_type=model_type, model_name=model_name)
|
||||
provider=provider, model_type=model_type, model_name=model_name,
|
||||
permission_type=permission_type)
|
||||
model.save()
|
||||
if status == Status.DOWNLOAD:
|
||||
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
|
||||
|
|
@ -184,7 +202,8 @@ class ModelSerializer(serializers.Serializer):
|
|||
'meta': model.meta,
|
||||
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
|
||||
model.model_name).encryption_dict(
|
||||
credential)}
|
||||
credential),
|
||||
'permission_type': model.permission_type}
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
|
|
@ -210,7 +229,8 @@ class ModelSerializer(serializers.Serializer):
|
|||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'status': model.status,
|
||||
'meta': model.meta, }
|
||||
'meta': model.meta
|
||||
}
|
||||
|
||||
def delete(self, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -221,6 +241,12 @@ class ModelSerializer(serializers.Serializer):
|
|||
QuerySet(Model).filter(id=self.data.get('id')).delete()
|
||||
return True
|
||||
|
||||
def pause_download(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
|
||||
return True
|
||||
|
||||
def edit(self, instance: Dict, user_id: str, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
@ -245,7 +271,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
model.status = Status.DOWNLOAD
|
||||
else:
|
||||
raise e
|
||||
update_keys = ['credential', 'name', 'model_type', 'model_name']
|
||||
update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'credential':
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
|
|||
'provider': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="供应商",
|
||||
description="供应商"),
|
||||
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
|
||||
description="PUBLIC|PRIVATE"),
|
||||
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
|
||||
title="供应商",
|
||||
description="供应商"),
|
||||
|
|
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
|
|||
description="供应商"),
|
||||
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
|
||||
title="模型证书信息",
|
||||
description="模型证书信息")
|
||||
description="模型证书信息"),
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ urlpatterns = [
|
|||
name="provider/model_form"),
|
||||
path('model', views.Model.as_view(), name='model'),
|
||||
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
|
||||
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
||||
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'),
|
||||
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())
|
||||
|
|
|
|||
|
|
@ -69,6 +69,18 @@ class Model(APIView):
|
|||
return result.success(
|
||||
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
|
||||
|
||||
class PauseDownload(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['PUT'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="暂停模型下载",
|
||||
operation_id="暂停模型下载",
|
||||
tags=["模型"])
|
||||
@has_permissions(PermissionConstants.MODEL_CREATE)
|
||||
def put(self, request: Request, model_id: str):
|
||||
return result.success(
|
||||
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())
|
||||
|
||||
class Operate(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -104,9 +104,6 @@ CACHES = {
|
|||
"token_cache": {
|
||||
'BACKEND': 'common.cache.file_cache.FileCache',
|
||||
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
|
||||
},
|
||||
"chat_cache": {
|
||||
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -129,11 +129,12 @@ const getDatasetDetail: (dataset_id: string, loading?: Ref<boolean>) => Promise<
|
|||
"desc": true
|
||||
}
|
||||
*/
|
||||
const putDataset: (dataset_id: string, data: any) => Promise<Result<any>> = (
|
||||
dataset_id,
|
||||
data: any
|
||||
) => {
|
||||
return put(`${prefix}/${dataset_id}`, data)
|
||||
const putDataset: (
|
||||
dataset_id: string,
|
||||
data: any,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (dataset_id, data, loading) => {
|
||||
return put(`${prefix}/${dataset_id}`, data, undefined, loading)
|
||||
}
|
||||
/**
|
||||
* 获取知识库 可关联的应用列表
|
||||
|
|
|
|||
|
|
@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Re
|
|||
) => {
|
||||
return get(`${prefix}/${model_id}/meta`, {}, loading)
|
||||
}
|
||||
|
||||
/**
|
||||
* 暂停下载
|
||||
* @param model_id 模型id
|
||||
* @param loading 加载器
|
||||
* @returns
|
||||
*/
|
||||
const pauseDownload: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||
model_id,
|
||||
loading
|
||||
) => {
|
||||
return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading)
|
||||
}
|
||||
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
|
||||
model_id,
|
||||
loading
|
||||
|
|
@ -147,5 +158,6 @@ export default {
|
|||
updateModel,
|
||||
deleteModel,
|
||||
getModelById,
|
||||
getModelMetaById
|
||||
getModelMetaById,
|
||||
pauseDownload
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ interface datasetData {
|
|||
desc: String
|
||||
documents?: Array<any>
|
||||
type?: String
|
||||
embedding_mode_id?: String
|
||||
}
|
||||
|
||||
export type { datasetData }
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ interface Model {
|
|||
* 模型类型
|
||||
*/
|
||||
model_type: string
|
||||
permission_type: 'PUBLIC' | 'PRIVATE'
|
||||
/**
|
||||
* 基础模型
|
||||
*/
|
||||
|
|
@ -68,7 +69,7 @@ interface Model {
|
|||
/**
|
||||
* 状态
|
||||
*/
|
||||
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
|
||||
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD'
|
||||
/**
|
||||
* 元数据
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -18,7 +18,8 @@
|
|||
</slot>
|
||||
<slot></slot>
|
||||
</div>
|
||||
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)"> </el-checkbox>
|
||||
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)" @change="checkboxChange">
|
||||
</el-checkbox>
|
||||
</div>
|
||||
</el-card>
|
||||
</template>
|
||||
|
|
@ -40,7 +41,7 @@ const toModelValue = computed(() => (props.valueField ? props.data[props.valueFi
|
|||
// set: (val) => val
|
||||
// })
|
||||
|
||||
const emit = defineEmits(['update:modelValue'])
|
||||
const emit = defineEmits(['update:modelValue', 'change'])
|
||||
|
||||
const checked = () => {
|
||||
const value = props.modelValue ? props.modelValue : []
|
||||
|
|
@ -53,6 +54,10 @@ const checked = () => {
|
|||
emit('update:modelValue', [...value, toModelValue.value])
|
||||
}
|
||||
}
|
||||
|
||||
function checkboxChange() {
|
||||
emit('change')
|
||||
}
|
||||
</script>
|
||||
<style lang="scss" scoped>
|
||||
.card-checkbox {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,93 @@
|
|||
<template>
|
||||
<div class="loading-container loader">
|
||||
<div class="download-loading">
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
<div></div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
<script setup lang="ts"></script>
|
||||
<style lang="scss" scoped>
|
||||
.loading-container {
|
||||
display: -webkit-flex; /*safari弹性布局*/
|
||||
justify-content: center;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
@-webkit-keyframes loader {
|
||||
0% {
|
||||
opacity: 0.3;
|
||||
}
|
||||
80% {
|
||||
opacity: 1;
|
||||
}
|
||||
100% {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
.download-loading {
|
||||
position: relative;
|
||||
}
|
||||
.download-loading div {
|
||||
width: 5px;
|
||||
height: 12px;
|
||||
background: var(--el-color-info);
|
||||
position: absolute;
|
||||
border-radius: 2px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
.download-loading div:nth-child(1) {
|
||||
top: -20px;
|
||||
left: 0;
|
||||
-webkit-animation: loader 1s -0.8s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(2) {
|
||||
top: -13px;
|
||||
left: 13px;
|
||||
-webkit-transform: rotate(45deg);
|
||||
-webkit-animation: loader 1s -0.6s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(3) {
|
||||
top: 0px;
|
||||
left: 20px;
|
||||
-webkit-transform: rotate(90deg);
|
||||
-webkit-animation: loader 1s -0.5s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(4) {
|
||||
top: 13px;
|
||||
left: 13px;
|
||||
-webkit-transform: rotate(-45deg);
|
||||
-webkit-animation: loader 1s -0.4s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(5) {
|
||||
top: 20px;
|
||||
left: 0px;
|
||||
-webkit-transform: rotate(0deg);
|
||||
-webkit-animation: loader 1s -0.3s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(6) {
|
||||
top: 13px;
|
||||
left: -13px;
|
||||
-webkit-transform: rotate(45deg);
|
||||
-webkit-animation: loader 1s -0.2s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(7) {
|
||||
top: 0px;
|
||||
left: -20px;
|
||||
-webkit-transform: rotate(90deg);
|
||||
-webkit-animation: loader 1s -0.1s infinite ease-in-out;
|
||||
}
|
||||
.download-loading div:nth-child(8) {
|
||||
top: -13px;
|
||||
left: -13px;
|
||||
-webkit-transform: rotate(-45deg);
|
||||
-webkit-animation: loader 1s 0s infinite ease-in-out;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
export enum PermissionType {
|
||||
PRIVATE = '私有',
|
||||
PUBLIC = '公用'
|
||||
}
|
||||
export enum PermissionDesc {
|
||||
PRIVATE = '仅自己使用',
|
||||
PUBLIC = '所有用户都可使用,不能编辑'
|
||||
}
|
||||
|
|
@ -99,7 +99,7 @@
|
|||
</div>
|
||||
</template>
|
||||
<template v-else-if="isDataset">
|
||||
<div class="w-full text-left cursor" @click="router.push({ path: '/dataset/create' })">
|
||||
<div class="w-full text-left cursor" @click="openCreateDialog">
|
||||
<el-button link>
|
||||
<el-icon class="mr-4"><Plus /></el-icon> 创建知识库
|
||||
</el-button>
|
||||
|
|
@ -110,12 +110,14 @@
|
|||
</el-dropdown>
|
||||
</div>
|
||||
<CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" />
|
||||
<CreateDatasetDialog ref="CreateDatasetDialogRef" @refresh="refresh" />
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, computed } from 'vue'
|
||||
import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router'
|
||||
import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue'
|
||||
import CreateDatasetDialog from '@/views/dataset/component/CreateDatasetDialog.vue'
|
||||
import { isAppIcon, isWorkFlow } from '@/utils/application'
|
||||
import useStore from '@/stores'
|
||||
const { common, dataset, application } = useStore()
|
||||
|
|
@ -130,6 +132,7 @@ onBeforeRouteLeave((to, from) => {
|
|||
common.saveBreadcrumb(null)
|
||||
})
|
||||
|
||||
const CreateDatasetDialogRef = ref()
|
||||
const CreateApplicationDialogRef = ref()
|
||||
const list = ref<any[]>([])
|
||||
const loading = ref(false)
|
||||
|
|
@ -148,7 +151,11 @@ const isDataset = computed(() => {
|
|||
})
|
||||
|
||||
function openCreateDialog() {
|
||||
CreateApplicationDialogRef.value.open()
|
||||
if (isDataset.value) {
|
||||
CreateDatasetDialogRef.value.open()
|
||||
} else if (isApplication.value) {
|
||||
CreateApplicationDialogRef.value.open()
|
||||
}
|
||||
}
|
||||
|
||||
function changeMenu(id: string) {
|
||||
|
|
|
|||
|
|
@ -13,9 +13,9 @@ const datasetRouter = {
|
|||
},
|
||||
{
|
||||
path: '/dataset/:type', // create 或者 upload
|
||||
name: 'CreateDataset',
|
||||
name: 'UploadDocumentDataset',
|
||||
meta: { activeMenu: '/dataset' },
|
||||
component: () => import('@/views/dataset/CreateDataset.vue'),
|
||||
component: () => import('@/views/dataset/UploadDocumentDataset.vue'),
|
||||
hidden: true
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
import { defineStore } from 'pinia'
|
||||
import modelApi from '@/api/model'
|
||||
import type { modelRequest, Provider } from '@/api/type/model'
|
||||
import type { ListModelRequest, Provider } from '@/api/type/model'
|
||||
const useModelStore = defineStore({
|
||||
id: 'model',
|
||||
state: () => ({}),
|
||||
actions: {
|
||||
async asyncGetModel(data?: modelRequest) {
|
||||
async asyncGetModel(data?: ListModelRequest) {
|
||||
return new Promise((resolve, reject) => {
|
||||
modelApi
|
||||
.getModel(data)
|
||||
|
|
|
|||
|
|
@ -377,6 +377,11 @@ h5 {
|
|||
color: var(--el-color-primary);
|
||||
border: none;
|
||||
}
|
||||
.danger-tag {
|
||||
background: var(--tag-danger-bg);
|
||||
color: #d03f3b;
|
||||
border: none;
|
||||
}
|
||||
.success-tag {
|
||||
background: var(--tag-success-bg);
|
||||
color: var(--el-color-success);
|
||||
|
|
@ -388,6 +393,12 @@ h5 {
|
|||
border: none;
|
||||
}
|
||||
|
||||
.info-tag {
|
||||
background: var(--app-text-color-light-1);
|
||||
color: var(--app-text-color-secondary);
|
||||
border: none;
|
||||
}
|
||||
|
||||
.purple-tag {
|
||||
background: #f2ebfe;
|
||||
color: #7f3bf5;
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@
|
|||
--tag-success-color: #2ca91f;
|
||||
--tag-warning-bg: rgba(255, 136, 0, 0.2);
|
||||
--tag-warning-color: #d97400;
|
||||
--tag-danger-bg: rgba(245, 74, 69, 0.2);
|
||||
|
||||
/** card */
|
||||
--card-width: 330px;
|
||||
|
|
|
|||
|
|
@ -4,21 +4,28 @@
|
|||
v-model="dialogVisible"
|
||||
width="600"
|
||||
append-to-body
|
||||
class="addDataset-dialog"
|
||||
>
|
||||
<template #header="{ titleId, titleClass }">
|
||||
<div class="my-header flex">
|
||||
<div class="flex-between mb-8">
|
||||
<h4 :id="titleId" :class="titleClass">
|
||||
{{ $t('views.application.applicationForm.dialogues.addDataset') }}
|
||||
</h4>
|
||||
<el-button link class="ml-16" @click="refresh">
|
||||
<el-icon class="mr-4"><Refresh /></el-icon
|
||||
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
|
||||
</el-button>
|
||||
<div class="flex align-center">
|
||||
<el-button link class="ml-16" @click="refresh">
|
||||
<el-icon class="mr-4"><Refresh /></el-icon
|
||||
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
|
||||
</el-button>
|
||||
<el-divider direction="vertical" />
|
||||
</div>
|
||||
</div>
|
||||
<el-text type="info" class="color-secondary">
|
||||
所选知识库必须使用相同的 Embedding 模型
|
||||
</el-text>
|
||||
</template>
|
||||
<el-row :gutter="12" v-loading="loading">
|
||||
<el-col :span="12" v-for="(item, index) in data" :key="index" class="mb-16">
|
||||
<CardCheckbox value-field="id" :data="item" v-model="checkList">
|
||||
<el-col :span="12" v-for="(item, index) in filterData" :key="index" class="mb-16">
|
||||
<CardCheckbox value-field="id" :data="item" v-model="checkList" @change="changeHandle">
|
||||
<span class="ellipsis">
|
||||
{{ item.name }}
|
||||
</span>
|
||||
|
|
@ -26,19 +33,29 @@
|
|||
</el-col>
|
||||
</el-row>
|
||||
<template #footer>
|
||||
<span class="dialog-footer">
|
||||
<el-button @click.prevent="dialogVisible = false">
|
||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||
</el-button>
|
||||
<el-button type="primary" @click="submitHandle">
|
||||
{{ $t('views.application.applicationForm.buttons.confirm') }}
|
||||
</el-button>
|
||||
</span>
|
||||
<div class="flex-between">
|
||||
<div>
|
||||
<el-text type="info" class="color-secondary" v-if="checkList.length > 0">
|
||||
已选 {{ checkList.length }} 个知识库
|
||||
</el-text>
|
||||
<el-button link type="primary" v-if="checkList.length > 0" @click="clearCheck">
|
||||
清空
|
||||
</el-button>
|
||||
</div>
|
||||
<span>
|
||||
<el-button @click.prevent="dialogVisible = false">
|
||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||
</el-button>
|
||||
<el-button type="primary" @click="submitHandle">
|
||||
{{ $t('views.application.applicationForm.buttons.confirm') }}
|
||||
</el-button>
|
||||
</span>
|
||||
</div>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, watch } from 'vue'
|
||||
import { computed, ref, watch } from 'vue'
|
||||
const props = defineProps({
|
||||
data: {
|
||||
type: Array<any>,
|
||||
|
|
@ -51,6 +68,13 @@ const emit = defineEmits(['addData', 'refresh'])
|
|||
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
const checkList = ref([])
|
||||
const currentEmbedding = ref('')
|
||||
|
||||
const filterData = computed(() => {
|
||||
return currentEmbedding.value
|
||||
? props.data.filter((v) => v.embedding_mode_id === currentEmbedding.value)
|
||||
: props.data
|
||||
})
|
||||
|
||||
watch(dialogVisible, (bool) => {
|
||||
if (!bool) {
|
||||
|
|
@ -58,6 +82,18 @@ watch(dialogVisible, (bool) => {
|
|||
}
|
||||
})
|
||||
|
||||
function changeHandle() {
|
||||
if (checkList.value.length === 1) {
|
||||
currentEmbedding.value = props.data.filter(
|
||||
(v) => v.id === checkList.value[0]
|
||||
)[0].embedding_mode_id
|
||||
}
|
||||
}
|
||||
function clearCheck() {
|
||||
checkList.value = []
|
||||
currentEmbedding.value = ''
|
||||
}
|
||||
|
||||
const open = (checked: any) => {
|
||||
checkList.value = checked
|
||||
dialogVisible.value = true
|
||||
|
|
@ -73,4 +109,13 @@ const refresh = () => {
|
|||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss" scope></style>
|
||||
<style lang="scss" scope>
|
||||
.addDataset-dialog {
|
||||
.el-dialog__header.show-close {
|
||||
padding-right: 15px;
|
||||
}
|
||||
.el-dialog__headerbtn {
|
||||
top: 13px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -63,10 +63,10 @@
|
|||
</el-form>
|
||||
<template #footer>
|
||||
<span class="dialog-footer">
|
||||
<el-button @click.prevent="dialogVisible = false">
|
||||
<el-button @click.prevent="dialogVisible = false" :loading="loading">
|
||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||
</el-button>
|
||||
<el-button type="primary" @click="submitValid(applicationFormRef)">
|
||||
<el-button type="primary" @click="submitValid(applicationFormRef)" :loading="loading">
|
||||
{{ $t('views.application.applicationForm.buttons.create') }}
|
||||
</el-button>
|
||||
</span>
|
||||
|
|
@ -183,10 +183,7 @@ const submitValid = (formEl: FormInstance | undefined) => {
|
|||
if (res?.data) {
|
||||
submitHandle(formEl)
|
||||
} else {
|
||||
MsgAlert(
|
||||
'提示',
|
||||
'社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。'
|
||||
)
|
||||
MsgAlert('提示', '社区版最多支持 5 个应用,如需拥有更多应用,请升级为专业版。')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
<template>
|
||||
<div class="authentication-setting p-24">
|
||||
<div class="authentication-setting p-16-24">
|
||||
<h4>{{ $t('login.authentication') }}</h4>
|
||||
<el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick">
|
||||
<template v-for="(item, index) in tabList" :key="index">
|
||||
|
|
@ -38,7 +38,7 @@ onMounted(() => {})
|
|||
background-color: var(--app-view-bg-color);
|
||||
box-sizing: border-box;
|
||||
min-width: 700px;
|
||||
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 80px);
|
||||
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 70px);
|
||||
box-sizing: border-box;
|
||||
.form-container {
|
||||
width: 70%;
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
<div class="dataset-setting main-calc-height">
|
||||
<el-scrollbar>
|
||||
<div class="p-24" v-loading="loading">
|
||||
<h4 class="title-decoration-1 mb-16">基本信息</h4>
|
||||
<BaseForm ref="BaseFormRef" :data="detail" />
|
||||
|
||||
<el-form
|
||||
|
|
@ -104,7 +105,7 @@ import { useRoute } from 'vue-router'
|
|||
import BaseForm from '@/views/dataset/component/BaseForm.vue'
|
||||
import datasetApi from '@/api/dataset'
|
||||
import type { ApplicationFormType } from '@/api/type/application'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||
import { isAppIcon } from '@/utils/application'
|
||||
import useStore from '@/stores'
|
||||
const route = useRoute()
|
||||
|
|
@ -119,6 +120,8 @@ const loading = ref(false)
|
|||
const detail = ref<any>({})
|
||||
const application_list = ref<Array<ApplicationFormType>>([])
|
||||
const application_id_list = ref([])
|
||||
const cloneModelId = ref('')
|
||||
|
||||
const form = ref<any>({
|
||||
source_url: '',
|
||||
selector: ''
|
||||
|
|
@ -132,7 +135,6 @@ async function submit() {
|
|||
if (await BaseFormRef.value?.validate()) {
|
||||
await webFormRef.value.validate((valid: any) => {
|
||||
if (valid) {
|
||||
loading.value = true
|
||||
const obj =
|
||||
detail.value.type === '1'
|
||||
? {
|
||||
|
|
@ -144,15 +146,25 @@ async function submit() {
|
|||
application_id_list: application_id_list.value,
|
||||
...BaseFormRef.value.form
|
||||
}
|
||||
datasetApi
|
||||
.putDataset(id, obj)
|
||||
.then((res) => {
|
||||
|
||||
if (cloneModelId.value !== BaseFormRef.value.form.embedding_mode_id) {
|
||||
MsgConfirm(`提示`, `修改知识库向量模型后,需要对知识库重新向量化,是否继续保存?`, {
|
||||
confirmButtonText: '重新向量化',
|
||||
confirmButtonClass: 'primary'
|
||||
})
|
||||
.then(() => {
|
||||
datasetApi.putDataset(id, obj, loading).then((res) => {
|
||||
datasetApi.putReEmbeddingDataset(id).then(() => {
|
||||
MsgSuccess('保存成功')
|
||||
})
|
||||
})
|
||||
})
|
||||
.catch(() => {})
|
||||
} else {
|
||||
datasetApi.putDataset(id, obj, loading).then((res) => {
|
||||
MsgSuccess('保存成功')
|
||||
loading.value = false
|
||||
})
|
||||
.catch(() => {
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
@ -161,10 +173,10 @@ async function submit() {
|
|||
function getDetail() {
|
||||
dataset.asyncGetDatasetDetail(id, loading).then((res: any) => {
|
||||
detail.value = res.data
|
||||
cloneModelId.value = res.data?.embedding_mode_id
|
||||
if (detail.value.type === '1') {
|
||||
form.value = res.data.meta
|
||||
}
|
||||
|
||||
application_id_list.value = res.data?.application_id_list
|
||||
datasetApi.listUsableApplication(id, loading).then((ok) => {
|
||||
application_list.value = ok.data
|
||||
|
|
|
|||
|
|
@ -1,15 +1,18 @@
|
|||
<template>
|
||||
<LayoutContainer :header="isCreate ? '创建知识库' : '上传文档'" class="create-dataset">
|
||||
<LayoutContainer header="上传文档" class="create-dataset">
|
||||
<template #backButton>
|
||||
<back-button @click="back"></back-button>
|
||||
</template>
|
||||
<div class="create-dataset__main flex" v-loading="loading">
|
||||
<div class="create-dataset__component main-calc-height">
|
||||
<template v-if="active === 0">
|
||||
<StepFirst ref="StepFirstRef" />
|
||||
<div class="upload-document p-24">
|
||||
<!-- 上传文档 -->
|
||||
<UploadComponent ref="UploadComponentRef" />
|
||||
</div>
|
||||
</template>
|
||||
<template v-else-if="active === 1">
|
||||
<StepSecond ref="StepSecondRef" />
|
||||
<SetRules ref="SetRulesRef" />
|
||||
</template>
|
||||
<template v-else-if="active === 2">
|
||||
<ResultSuccess :data="successInfo" />
|
||||
|
|
@ -19,12 +22,7 @@
|
|||
<div class="create-dataset__footer text-right border-t" v-if="active !== 2">
|
||||
<el-button @click="router.go(-1)" :disabled="loading">取消</el-button>
|
||||
<el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button>
|
||||
<el-button
|
||||
@click="next"
|
||||
type="primary"
|
||||
v-if="active === 0"
|
||||
:disabled="loading || StepFirstRef?.loading"
|
||||
>
|
||||
<el-button @click="next" type="primary" v-if="active === 0" :disabled="loading">
|
||||
创建并导入
|
||||
</el-button>
|
||||
<el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading">
|
||||
|
|
@ -36,9 +34,9 @@
|
|||
<script setup lang="ts">
|
||||
import { ref, computed, onUnmounted } from 'vue'
|
||||
import { useRouter, useRoute } from 'vue-router'
|
||||
import StepFirst from './step/StepFirst.vue'
|
||||
import StepSecond from './step/StepSecond.vue'
|
||||
import ResultSuccess from './step/ResultSuccess.vue'
|
||||
import SetRules from './component/SetRules.vue'
|
||||
import ResultSuccess from './component/ResultSuccess.vue'
|
||||
import UploadComponent from './component/UploadComponent.vue'
|
||||
import datasetApi from '@/api/dataset'
|
||||
import documentApi from '@/api/document'
|
||||
import type { datasetData } from '@/api/type/dataset'
|
||||
|
|
@ -46,33 +44,17 @@ import { MsgConfirm, MsgSuccess } from '@/utils/message'
|
|||
|
||||
import useStore from '@/stores'
|
||||
const { dataset, document } = useStore()
|
||||
const baseInfo = computed(() => dataset.baseInfo)
|
||||
const webInfo = computed(() => dataset.webInfo)
|
||||
const documentsFiles = computed(() => dataset.documentsFiles)
|
||||
const documentsType = computed(() => dataset.documentsType)
|
||||
|
||||
const router = useRouter()
|
||||
const route = useRoute()
|
||||
const {
|
||||
params: { type },
|
||||
query: { id } // id为datasetID,有id的是上传文档
|
||||
} = route
|
||||
const isCreate = type === 'create'
|
||||
// const steps = [
|
||||
// {
|
||||
// ref: 'StepFirstRef',
|
||||
// name: '上传文档',
|
||||
// component: StepFirst
|
||||
// },
|
||||
// {
|
||||
// ref: 'StepSecondRef',
|
||||
// name: '设置分段规则',
|
||||
// component: StepSecond
|
||||
// }
|
||||
// ]
|
||||
|
||||
const StepFirstRef = ref()
|
||||
const StepSecondRef = ref()
|
||||
const SetRulesRef = ref()
|
||||
const UploadComponentRef = ref()
|
||||
|
||||
const loading = ref(false)
|
||||
const disabled = ref(false)
|
||||
|
|
@ -81,7 +63,7 @@ const successInfo = ref<any>(null)
|
|||
|
||||
async function next() {
|
||||
disabled.value = true
|
||||
if (await StepFirstRef.value?.onSubmit()) {
|
||||
if (await UploadComponentRef.value.validate()) {
|
||||
if (documentsType.value === 'QA') {
|
||||
let fd = new FormData()
|
||||
documentsFiles.value.forEach((item: any) => {
|
||||
|
|
@ -118,16 +100,14 @@ const prev = () => {
|
|||
}
|
||||
|
||||
function clearStore() {
|
||||
dataset.saveBaseInfo(null)
|
||||
dataset.saveWebInfo(null)
|
||||
dataset.saveDocumentsFile([])
|
||||
dataset.saveDocumentsType('')
|
||||
}
|
||||
function submit() {
|
||||
loading.value = true
|
||||
const documents = [] as any
|
||||
StepSecondRef.value?.paragraphList.map((item: any) => {
|
||||
if (!StepSecondRef.value?.checkedConnect) {
|
||||
SetRulesRef.value?.paragraphList.map((item: any) => {
|
||||
if (!SetRulesRef.value?.checkedConnect) {
|
||||
item.content.map((v: any) => {
|
||||
delete v['problem_list']
|
||||
})
|
||||
|
|
@ -159,7 +139,7 @@ function submit() {
|
|||
}
|
||||
}
|
||||
function back() {
|
||||
if (baseInfo.value || webInfo.value || documentsFiles.value?.length > 0) {
|
||||
if (documentsFiles.value?.length > 0) {
|
||||
MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, {
|
||||
confirmButtonText: '确认',
|
||||
type: 'warning'
|
||||
|
|
@ -206,5 +186,10 @@ onUnmounted(() => {
|
|||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
.upload-document {
|
||||
width: 70%;
|
||||
margin: 0 auto;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
<template>
|
||||
<h4 class="title-decoration-1 mb-16">基本信息</h4>
|
||||
<el-form
|
||||
ref="FormRef"
|
||||
:model="form"
|
||||
:rules="rules"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
v-loading="loading"
|
||||
>
|
||||
<el-form-item label="知识库名称" prop="name">
|
||||
<el-input
|
||||
|
|
@ -27,14 +27,72 @@
|
|||
@blur="form.desc = form.desc.trim()"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="Embedding模型" prop="embedding_mode_id">
|
||||
<el-select
|
||||
v-model="form.embedding_mode_id"
|
||||
placeholder="请选择Embedding模型"
|
||||
class="w-full"
|
||||
popper-class="select-model"
|
||||
:clearable="true"
|
||||
>
|
||||
<el-option-group
|
||||
v-for="(value, label) in modelOptions"
|
||||
:key="value"
|
||||
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||
>
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
>
|
||||
<div class="flex">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
|
||||
><Check
|
||||
/></el-icon>
|
||||
</el-option>
|
||||
<!-- 不可用 -->
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
disabled
|
||||
>
|
||||
<div class="flex">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<span class="danger">{{
|
||||
$t('views.application.applicationForm.form.aiModel.unavailable')
|
||||
}}</span>
|
||||
</div>
|
||||
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
|
||||
><Check
|
||||
/></el-icon>
|
||||
</el-option>
|
||||
</el-option-group>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import { groupBy } from 'lodash'
|
||||
import useStore from '@/stores'
|
||||
import type { datasetData } from '@/api/type/dataset'
|
||||
import { isAllPropertiesEmpty } from '@/utils/utils'
|
||||
import { relatedObject } from '@/utils/utils'
|
||||
import type { Provider } from '@/api/type/model'
|
||||
|
||||
const props = defineProps({
|
||||
data: {
|
||||
|
|
@ -42,23 +100,23 @@ const props = defineProps({
|
|||
default: () => {}
|
||||
}
|
||||
})
|
||||
const route = useRoute()
|
||||
const {
|
||||
params: { type }
|
||||
} = route
|
||||
const isCreate = type === 'create'
|
||||
const { dataset } = useStore()
|
||||
const baseInfo = computed(() => dataset.baseInfo)
|
||||
const { model } = useStore()
|
||||
const form = ref<datasetData>({
|
||||
name: '',
|
||||
desc: ''
|
||||
desc: '',
|
||||
embedding_mode_id: ''
|
||||
})
|
||||
|
||||
const rules = reactive({
|
||||
name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }],
|
||||
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }]
|
||||
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }],
|
||||
embedding_mode_id: [{ required: true, message: '请输入Embedding模型', trigger: 'change' }]
|
||||
})
|
||||
|
||||
const FormRef = ref()
|
||||
const loading = ref(false)
|
||||
const modelOptions = ref<any>([])
|
||||
const providerOptions = ref<Array<Provider>>([])
|
||||
|
||||
watch(
|
||||
() => props.data,
|
||||
|
|
@ -66,23 +124,13 @@ watch(
|
|||
if (value && JSON.stringify(value) !== '{}') {
|
||||
form.value.name = value.name
|
||||
form.value.desc = value.desc
|
||||
form.value.embedding_mode_id = value.embedding_mode_id
|
||||
}
|
||||
},
|
||||
{
|
||||
immediate: true
|
||||
}
|
||||
)
|
||||
|
||||
watch(form.value, (value) => {
|
||||
if (isAllPropertiesEmpty(value)) {
|
||||
dataset.saveBaseInfo(null)
|
||||
} else {
|
||||
if (isCreate) {
|
||||
dataset.saveBaseInfo(value)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
/*
|
||||
表单校验
|
||||
*/
|
||||
|
|
@ -93,16 +141,43 @@ function validate() {
|
|||
})
|
||||
}
|
||||
|
||||
function getModel() {
|
||||
loading.value = true
|
||||
model
|
||||
.asyncGetModel({ model_type: 'EMBEDDING' })
|
||||
.then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
loading.value = false
|
||||
})
|
||||
.catch(() => {
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
|
||||
function getProvider() {
|
||||
loading.value = true
|
||||
model
|
||||
.asyncGetProvider()
|
||||
.then((res: any) => {
|
||||
providerOptions.value = res?.data
|
||||
loading.value = false
|
||||
})
|
||||
.catch(() => {
|
||||
loading.value = false
|
||||
})
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
if (baseInfo.value) {
|
||||
form.value = baseInfo.value
|
||||
}
|
||||
getProvider()
|
||||
getModel()
|
||||
})
|
||||
onUnmounted(() => {
|
||||
form.value = {
|
||||
name: '',
|
||||
desc: ''
|
||||
desc: '',
|
||||
embedding_mode_id: ''
|
||||
}
|
||||
FormRef.value?.clearValidate()
|
||||
})
|
||||
|
||||
defineExpose({
|
||||
|
|
|
|||
|
|
@ -0,0 +1,177 @@
|
|||
<template>
|
||||
<el-dialog title="创建知识库" v-model="dialogVisible" width="650" append-to-body>
|
||||
<!-- 基本信息 -->
|
||||
<BaseForm ref="BaseFormRef" v-if="dialogVisible" />
|
||||
<el-form
|
||||
ref="DatasetFormRef"
|
||||
:rules="rules"
|
||||
:model="datasetForm"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
>
|
||||
<el-form-item label="知识库类型" required>
|
||||
<el-radio-group v-model="datasetForm.type" class="card__radio" @change="radioChange">
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-card
|
||||
shadow="never"
|
||||
class="mb-16"
|
||||
:class="datasetForm.type === '0' ? 'active' : ''"
|
||||
>
|
||||
<el-radio value="0" size="large">
|
||||
<div class="flex align-center">
|
||||
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
|
||||
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
|
||||
</AppAvatar>
|
||||
<div>
|
||||
<p class="mb-4">通用型</p>
|
||||
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-card
|
||||
shadow="never"
|
||||
class="mb-16"
|
||||
:class="datasetForm.type === '1' ? 'active' : ''"
|
||||
>
|
||||
<el-radio value="1" size="large">
|
||||
<div class="flex align-center">
|
||||
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
|
||||
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
|
||||
</AppAvatar>
|
||||
<div>
|
||||
<p class="mb-4">Web 站点</p>
|
||||
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
<el-form-item label="Web 根地址" prop="source_url" v-if="datasetForm.type === '1'">
|
||||
<el-input
|
||||
v-model="datasetForm.source_url"
|
||||
placeholder="请输入 Web 根地址"
|
||||
@blur="datasetForm.source_url = datasetForm.source_url.trim()"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="选择器" v-if="datasetForm.type === '1'">
|
||||
<el-input
|
||||
v-model="datasetForm.selector"
|
||||
placeholder="默认为 body,可输入 .classname/#idname/tagname"
|
||||
@blur="datasetForm.selector = datasetForm.selector.trim()"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
<template #footer>
|
||||
<span class="dialog-footer">
|
||||
<el-button @click.prevent="dialogVisible = false" :loading="loading">
|
||||
{{ $t('views.application.applicationForm.buttons.cancel') }}
|
||||
</el-button>
|
||||
<el-button type="primary" @click="submitValid" :loading="loading">
|
||||
{{ $t('views.application.applicationForm.buttons.create') }}
|
||||
</el-button>
|
||||
</span>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, reactive } from 'vue'
|
||||
import { useRouter, useRoute } from 'vue-router'
|
||||
import BaseForm from './BaseForm.vue'
|
||||
import datasetApi from '@/api/dataset'
|
||||
import { MsgSuccess, MsgAlert } from '@/utils/message'
|
||||
import useStore from '@/stores'
|
||||
import { ValidType, ValidCount } from '@/enums/common'
|
||||
|
||||
const emit = defineEmits(['refresh'])
|
||||
|
||||
const { common, user } = useStore()
|
||||
const router = useRouter()
|
||||
const BaseFormRef = ref()
|
||||
const DatasetFormRef = ref()
|
||||
|
||||
const loading = ref(false)
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
|
||||
const datasetForm = ref<any>({
|
||||
type: '0',
|
||||
source_url: '',
|
||||
selector: ''
|
||||
})
|
||||
|
||||
const rules = reactive({
|
||||
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
|
||||
})
|
||||
|
||||
watch(dialogVisible, (bool) => {
|
||||
if (!bool) {
|
||||
datasetForm.value = {
|
||||
type: '0',
|
||||
source_url: '',
|
||||
selector: ''
|
||||
}
|
||||
DatasetFormRef.value?.clearValidate()
|
||||
}
|
||||
})
|
||||
|
||||
const open = () => {
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
const submitValid = () => {
|
||||
if (user.isEnterprise()) {
|
||||
submitHandle()
|
||||
} else {
|
||||
common.asyncGetValid(ValidType.Dataset, ValidCount.Dataset, loading).then(async (res: any) => {
|
||||
if (res?.data) {
|
||||
submitHandle()
|
||||
} else {
|
||||
MsgAlert('提示', '社区版最多支持 50 个知识库,如需拥有更多知识库,请升级为专业版。')
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
const submitHandle = async () => {
|
||||
if (await BaseFormRef.value?.validate()) {
|
||||
await DatasetFormRef.value.validate((valid: any) => {
|
||||
if (valid) {
|
||||
if (datasetForm.value.type === '0') {
|
||||
const obj = {
|
||||
...BaseFormRef.value.form,
|
||||
type: datasetForm.value.type
|
||||
}
|
||||
datasetApi.postDataset(obj, loading).then((res) => {
|
||||
MsgSuccess('创建成功')
|
||||
router.push({ path: `/dataset/${res.data.id}/document` })
|
||||
emit('refresh')
|
||||
})
|
||||
} else {
|
||||
const obj = { ...BaseFormRef.value.form, ...datasetForm.value }
|
||||
datasetApi.postWebDataset(obj, loading).then((res) => {
|
||||
MsgSuccess('创建成功')
|
||||
router.push({ path: `/dataset/${res.data.id}/document` })
|
||||
emit('refresh')
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
})
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
function radioChange() {
|
||||
datasetForm.value.source_url = ''
|
||||
datasetForm.value.selector = ''
|
||||
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss" scope></style>
|
||||
|
|
@ -65,7 +65,7 @@
|
|||
show-input
|
||||
:show-input-controls="false"
|
||||
:min="50"
|
||||
:max="4096"
|
||||
:max="100000"
|
||||
/>
|
||||
</div>
|
||||
<div class="form-item mb-16">
|
||||
|
|
@ -22,7 +22,7 @@
|
|||
>
|
||||
<el-row :gutter="15">
|
||||
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
||||
<CardAdd title="创建知识库" @click="router.push({ path: '/dataset/create' })" />
|
||||
<CardAdd title="创建知识库" @click="openCreateDialog" />
|
||||
</el-col>
|
||||
<template v-for="(item, index) in datasetList" :key="index">
|
||||
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
|
||||
|
|
@ -107,17 +107,20 @@
|
|||
</InfiniteScroll>
|
||||
</div>
|
||||
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
|
||||
<CreateDatasetDialog ref="CreateDatasetDialogRef"/>
|
||||
</div>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, reactive, computed } from 'vue'
|
||||
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
|
||||
import CreateDatasetDialog from './component/CreateDatasetDialog.vue'
|
||||
import datasetApi from '@/api/dataset'
|
||||
import { MsgSuccess, MsgConfirm } from '@/utils/message'
|
||||
import { useRouter } from 'vue-router'
|
||||
import { numberFormat } from '@/utils/utils'
|
||||
const router = useRouter()
|
||||
|
||||
const CreateDatasetDialogRef = ref()
|
||||
const SyncWebDialogRef = ref()
|
||||
const loading = ref(false)
|
||||
const datasetList = ref<any[]>([])
|
||||
|
|
@ -129,6 +132,10 @@ const paginationConfig = reactive({
|
|||
|
||||
const searchValue = ref('')
|
||||
|
||||
function openCreateDialog() {
|
||||
CreateDatasetDialogRef.value.open()
|
||||
}
|
||||
|
||||
function refresh() {
|
||||
MsgSuccess('同步任务发送成功')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,183 +0,0 @@
|
|||
<template>
|
||||
<el-scrollbar>
|
||||
<div class="upload-document p-24">
|
||||
<!-- 基本信息 -->
|
||||
<BaseForm ref="BaseFormRef" v-if="isCreate" />
|
||||
<el-form
|
||||
v-if="isCreate"
|
||||
ref="webFormRef"
|
||||
:rules="rules"
|
||||
:model="form"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
>
|
||||
<el-form-item label="知识库类型" required>
|
||||
<el-radio-group v-model="form.type" class="card__radio" @change="radioChange">
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-card shadow="never" class="mb-16" :class="form.type === '0' ? 'active' : ''">
|
||||
<el-radio value="0" size="large">
|
||||
<div class="flex align-center">
|
||||
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
|
||||
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
|
||||
</AppAvatar>
|
||||
<div>
|
||||
<p class="mb-4">通用型</p>
|
||||
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-card shadow="never" class="mb-16" :class="form.type === '1' ? 'active' : ''">
|
||||
<el-radio value="1" size="large">
|
||||
<div class="flex align-center">
|
||||
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
|
||||
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
|
||||
</AppAvatar>
|
||||
<div>
|
||||
<p class="mb-4">Web 站点</p>
|
||||
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
|
||||
</div>
|
||||
</div>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
<el-form-item label="Web 根地址" prop="source_url" v-if="form.type === '1'">
|
||||
<el-input
|
||||
v-model="form.source_url"
|
||||
placeholder="请输入 Web 根地址"
|
||||
@blur="form.source_url = form.source_url.trim()"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="选择器" v-if="form.type === '1'">
|
||||
<el-input
|
||||
v-model="form.selector"
|
||||
placeholder="默认为 body,可输入 .classname/#idname/tagname"
|
||||
@blur="form.selector = form.selector.trim()"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
|
||||
<!-- 上传文档 -->
|
||||
<UploadComponent ref="UploadComponentRef" v-if="form.type === '0'" />
|
||||
</div>
|
||||
</el-scrollbar>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, reactive, watch } from 'vue'
|
||||
import { useRouter, useRoute } from 'vue-router'
|
||||
import BaseForm from '@/views/dataset/component/BaseForm.vue'
|
||||
import UploadComponent from '@/views/dataset/component/UploadComponent.vue'
|
||||
import { isAllPropertiesEmpty } from '@/utils/utils'
|
||||
import datasetApi from '@/api/dataset'
|
||||
import { MsgError, MsgSuccess } from '@/utils/message'
|
||||
import useStore from '@/stores'
|
||||
const { dataset } = useStore()
|
||||
|
||||
const route = useRoute()
|
||||
const router = useRouter()
|
||||
const {
|
||||
params: { type }
|
||||
} = route
|
||||
const isCreate = type === 'create'
|
||||
const BaseFormRef = ref()
|
||||
const UploadComponentRef = ref()
|
||||
const webFormRef = ref()
|
||||
const loading = ref(false)
|
||||
|
||||
const form = ref<any>({
|
||||
type: '0',
|
||||
source_url: '',
|
||||
selector: ''
|
||||
})
|
||||
|
||||
const rules = reactive({
|
||||
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
|
||||
})
|
||||
|
||||
watch(form.value, (value) => {
|
||||
if (isAllPropertiesEmpty(value)) {
|
||||
dataset.saveWebInfo(null)
|
||||
} else {
|
||||
dataset.saveWebInfo(value)
|
||||
}
|
||||
})
|
||||
|
||||
function radioChange() {
|
||||
dataset.saveDocumentsFile([])
|
||||
dataset.saveDocumentsType('')
|
||||
form.value.source_url = ''
|
||||
form.value.selector = ''
|
||||
}
|
||||
|
||||
const onSubmit = async () => {
|
||||
if (isCreate) {
|
||||
if (form.value.type === '0') {
|
||||
if ((await BaseFormRef.value?.validate()) && (await UploadComponentRef.value.validate())) {
|
||||
if (UploadComponentRef.value.form.fileList.length > 50) {
|
||||
MsgError('每次最多上传50个文件!')
|
||||
return false
|
||||
} else {
|
||||
/*
|
||||
stores保存数据
|
||||
*/
|
||||
dataset.saveBaseInfo(BaseFormRef.value.form)
|
||||
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
|
||||
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if (await BaseFormRef.value?.validate()) {
|
||||
await webFormRef.value.validate((valid: any) => {
|
||||
if (valid) {
|
||||
const obj = { ...BaseFormRef.value.form, ...form.value }
|
||||
datasetApi.postWebDataset(obj, loading).then((res) => {
|
||||
MsgSuccess('提交成功')
|
||||
dataset.saveBaseInfo(null)
|
||||
dataset.saveWebInfo(null)
|
||||
router.push({ path: `/dataset/${res.data.id}/document` })
|
||||
})
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
})
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (await UploadComponentRef.value.validate()) {
|
||||
/*
|
||||
stores保存数据
|
||||
*/
|
||||
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
|
||||
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
onMounted(() => {})
|
||||
|
||||
defineExpose({
|
||||
onSubmit,
|
||||
loading
|
||||
})
|
||||
</script>
|
||||
<style scoped lang="scss">
|
||||
.upload-document {
|
||||
width: 70%;
|
||||
margin: 0 auto;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
</style>
|
||||
|
|
@ -141,4 +141,13 @@ const submit = async (formEl: FormInstance) => {
|
|||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss" scope></style>
|
||||
<style lang="scss" scope>
|
||||
.edit-mark-dialog {
|
||||
.el-dialog__header.show-close {
|
||||
padding-right: 15px;
|
||||
}
|
||||
.el-dialog__headerbtn {
|
||||
top: 13px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
v-if="isEdit"
|
||||
v-model="form.content"
|
||||
placeholder="请输入分段内容"
|
||||
:maxLength="4096"
|
||||
:maxLength="100000"
|
||||
:preview="false"
|
||||
:toolbars="toolbars"
|
||||
style="height: 300px"
|
||||
|
|
@ -31,7 +31,7 @@
|
|||
:footers="footers"
|
||||
>
|
||||
<template #defFooters>
|
||||
<span style="margin-left: -6px">/ 4096</span>
|
||||
<span style="margin-left: -6px">/ 100000</span>
|
||||
</template>
|
||||
</MdEditor>
|
||||
<MdPreview
|
||||
|
|
|
|||
|
|
@ -55,6 +55,32 @@
|
|||
placeholder="请给基础模型设置一个名称"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
|
||||
<template #label>
|
||||
<span>权限</span>
|
||||
</template>
|
||||
|
||||
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
|
||||
<el-row :gutter="16">
|
||||
<template v-for="(value, key) of PermissionType" :key="key">
|
||||
<el-col :span="12">
|
||||
<el-card
|
||||
shadow="never"
|
||||
class="mb-16"
|
||||
:class="base_form_data.permission_type === key ? 'active' : ''"
|
||||
>
|
||||
<el-radio :value="key" size="large">
|
||||
<p class="mb-4">{{ value }}</p>
|
||||
<el-text type="info">
|
||||
{{ PermissionDesc[key] }}
|
||||
</el-text>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</template>
|
||||
</el-row>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
||||
<template #label>
|
||||
<span>模型类型</span>
|
||||
|
|
@ -74,6 +100,7 @@
|
|||
></el-option>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
|
||||
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
|
||||
<template #label>
|
||||
<div class="flex align-center" style="display: inline-flex">
|
||||
|
|
@ -135,6 +162,7 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||
import type { FormRules } from 'element-plus'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
import { PermissionType, PermissionDesc } from '@/enums/model'
|
||||
|
||||
const providerValue = ref<Provider>()
|
||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||
|
|
@ -150,17 +178,18 @@ const dialogVisible = ref<boolean>(false)
|
|||
|
||||
const base_form_data_rule = ref<FormRules>({
|
||||
name: { required: true, trigger: 'blur', message: '模型名不能为空' },
|
||||
permission_type: { required: true, trigger: 'change', message: '权限不能为空' },
|
||||
model_type: { required: true, trigger: 'change', message: '模型类型不能为空' },
|
||||
model_name: { required: true, trigger: 'change', message: '基础模型不能为空' }
|
||||
})
|
||||
|
||||
const base_form_data = ref<{
|
||||
name: string
|
||||
|
||||
permission_type: string
|
||||
model_type: string
|
||||
|
||||
model_name: string
|
||||
}>({ name: '', model_type: '', model_name: '' })
|
||||
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
|
||||
|
||||
const credential_form_data = ref<Dict<any>>({})
|
||||
|
||||
|
|
@ -212,7 +241,7 @@ const list_base_model = (model_type: any) => {
|
|||
}
|
||||
|
||||
const close = () => {
|
||||
base_form_data.value = { name: '', model_type: '', model_name: '' }
|
||||
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' }
|
||||
credential_form_data.value = {}
|
||||
model_form_field.value = []
|
||||
base_model_list.value = []
|
||||
|
|
|
|||
|
|
@ -48,6 +48,32 @@
|
|||
placeholder="请给基础模型设置一个名称"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
|
||||
<template #label>
|
||||
<span>权限</span>
|
||||
</template>
|
||||
|
||||
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
|
||||
<el-row :gutter="16">
|
||||
<template v-for="(value, key) of PermissionType" :key="key">
|
||||
<el-col :span="12">
|
||||
<el-card
|
||||
shadow="never"
|
||||
class="mb-16"
|
||||
:class="base_form_data.permission_type === key ? 'active' : ''"
|
||||
>
|
||||
<el-radio :value="key" size="large">
|
||||
<p class="mb-4">{{ value }}</p>
|
||||
<el-text type="info">
|
||||
{{ PermissionDesc[key] }}
|
||||
</el-text>
|
||||
</el-radio>
|
||||
</el-card>
|
||||
</el-col>
|
||||
</template>
|
||||
</el-row>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
|
||||
<template #label>
|
||||
<span>模型类型</span>
|
||||
|
|
@ -128,7 +154,7 @@ import type { FormField } from '@/components/dynamics-form/type'
|
|||
import DynamicsForm from '@/components/dynamics-form/index.vue'
|
||||
import type { FormRules } from 'element-plus'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
import AppIcon from '@/components/icons/AppIcon.vue'
|
||||
import { PermissionType, PermissionDesc } from '@/enums/model'
|
||||
|
||||
const providerValue = ref<Provider>()
|
||||
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
|
||||
|
|
@ -151,11 +177,11 @@ const base_form_data_rule = ref<FormRules>({
|
|||
|
||||
const base_form_data = ref<{
|
||||
name: string
|
||||
|
||||
permission_type: string
|
||||
model_type: string
|
||||
|
||||
model_name: string
|
||||
}>({ name: '', model_type: '', model_name: '' })
|
||||
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
|
||||
|
||||
const credential_form_data = ref<Dict<any>>({})
|
||||
|
||||
|
|
@ -204,6 +230,7 @@ const open = (provider: Provider, model: Model) => {
|
|||
|
||||
base_form_data.value = {
|
||||
name: model.name,
|
||||
permission_type: model.permission_type,
|
||||
model_type: model.model_type,
|
||||
model_name: model.model_name
|
||||
}
|
||||
|
|
@ -214,7 +241,7 @@ const open = (provider: Provider, model: Model) => {
|
|||
}
|
||||
|
||||
const close = () => {
|
||||
base_form_data.value = { name: '', model_type: '', model_name: '' }
|
||||
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: '' }
|
||||
dynamicsFormRef.value?.ruleFormRef?.resetFields()
|
||||
credential_form_data.value = {}
|
||||
model_form_field.value = []
|
||||
|
|
|
|||
|
|
@ -1,16 +1,30 @@
|
|||
<template>
|
||||
<card-box :title="model.name" shadow="hover" class="model-card">
|
||||
<template #header>
|
||||
<div class="flex align-center">
|
||||
<div class="flex">
|
||||
<span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span>
|
||||
<auto-tooltip :content="model.name" style="max-width: 40%">
|
||||
{{ model.name }}
|
||||
</auto-tooltip>
|
||||
<div class="flex align-center" v-if="currentModel.status === 'ERROR'">
|
||||
<el-tag type="danger" class="ml-8">失败</el-tag>
|
||||
<el-tooltip effect="dark" :content="errMessage" placement="top">
|
||||
<el-icon class="danger ml-4" size="20"><Warning /></el-icon>
|
||||
</el-tooltip>
|
||||
<div class="w-full">
|
||||
<div class="flex" style="height: 22px">
|
||||
<auto-tooltip :content="model.name" style="max-width: 40%">
|
||||
{{ model.name }}
|
||||
</auto-tooltip>
|
||||
<span v-if="currentModel.status === 'ERROR'">
|
||||
<el-tooltip effect="dark" :content="errMessage" placement="top">
|
||||
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
|
||||
</el-tooltip>
|
||||
</span>
|
||||
<span v-if="currentModel.status === 'PAUSE_DOWNLOAD'">
|
||||
<el-tooltip effect="dark" content="暂停下载" placement="top">
|
||||
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
|
||||
</el-tooltip>
|
||||
</span>
|
||||
</div>
|
||||
<div class="mt-4">
|
||||
<el-tag v-if="model.permission_type === 'PRIVATE'" type="danger" class="danger-tag"
|
||||
>私有</el-tag
|
||||
>
|
||||
<el-tag v-else type="info" class="info-tag">公有</el-tag>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
|
@ -29,18 +43,14 @@
|
|||
</div>
|
||||
<!-- progress -->
|
||||
<div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'">
|
||||
<el-progress
|
||||
type="circle"
|
||||
:width="56"
|
||||
color="#3370FF"
|
||||
:percentage="progress"
|
||||
class="percentage"
|
||||
>
|
||||
<template #default="{ percentage }">
|
||||
<span class="percentage-value">{{ percentage }}%</span>
|
||||
</template>
|
||||
</el-progress>
|
||||
<span class="percentage-label">正在下载 <span class="dotting"></span></span>
|
||||
<DownloadLoading class="percentage" />
|
||||
|
||||
<div class="percentage-label flex-center">
|
||||
正在下载中 <span class="dotting"></span>
|
||||
<el-button link type="primary" class="ml-16" @click.stop="cancelDownload"
|
||||
>取消下载</el-button
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<template #mouseEnter>
|
||||
|
|
@ -48,7 +58,13 @@
|
|||
<el-tooltip effect="dark" content="修改" placement="top">
|
||||
<el-button text @click.stop="openEditModel">
|
||||
<el-icon>
|
||||
<component :is="currentModel.status === 'ERROR' ? 'RefreshRight' : 'EditPen'" />
|
||||
<component
|
||||
:is="
|
||||
currentModel.status === 'ERROR' || currentModel.status === 'PAUSE_DOWNLOAD'
|
||||
? 'RefreshRight'
|
||||
: 'EditPen'
|
||||
"
|
||||
/>
|
||||
</el-icon>
|
||||
</el-button>
|
||||
</el-tooltip>
|
||||
|
|
@ -68,6 +84,7 @@ import type { Provider, Model } from '@/api/type/model'
|
|||
import ModelApi from '@/api/model'
|
||||
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
|
||||
import EditModel from '@/views/template/component/EditModel.vue'
|
||||
import DownloadLoading from '@/components/loading/DownloadLoading.vue'
|
||||
import { MsgConfirm } from '@/utils/message'
|
||||
|
||||
const props = defineProps<{
|
||||
|
|
@ -94,27 +111,6 @@ const errMessage = computed(() => {
|
|||
}
|
||||
return ''
|
||||
})
|
||||
const progress = computed(() => {
|
||||
if (currentModel.value) {
|
||||
const down_model_chunk = currentModel.value.meta['down_model_chunk']
|
||||
if (down_model_chunk) {
|
||||
const maxObj = down_model_chunk
|
||||
.filter((chunk: any) => chunk.index > 1)
|
||||
.reduce(
|
||||
(prev: any, current: any) => {
|
||||
return (prev.index || 0) > (current.index || 0) ? prev : current
|
||||
},
|
||||
{ progress: 0 }
|
||||
)
|
||||
if (maxObj) {
|
||||
return parseFloat(maxObj.progress?.toFixed(1))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
return 0
|
||||
})
|
||||
const emit = defineEmits(['change', 'update:model'])
|
||||
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
|
||||
let interval: any
|
||||
|
|
@ -130,6 +126,13 @@ const deleteModel = () => {
|
|||
})
|
||||
.catch(() => {})
|
||||
}
|
||||
|
||||
const cancelDownload = () => {
|
||||
ModelApi.pauseDownload(props.model.id).then(() => {
|
||||
downModel.value = undefined
|
||||
emit('change')
|
||||
})
|
||||
}
|
||||
const openEditModel = () => {
|
||||
const provider = props.provider_list.find((p) => p.provider === props.model.provider)
|
||||
if (provider) {
|
||||
|
|
@ -197,21 +200,21 @@ onBeforeUnmount(() => {
|
|||
z-index: 99;
|
||||
text-align: center;
|
||||
.percentage {
|
||||
top: 50%;
|
||||
transform: translateY(-65%);
|
||||
margin-top: 55px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.percentage-value {
|
||||
display: block;
|
||||
font-size: 12px;
|
||||
color: var(--el-color-primary);
|
||||
}
|
||||
// .percentage-value {
|
||||
// display: flex;
|
||||
// font-size: 13px;
|
||||
// align-items: center;
|
||||
// color: var(--app-text-color-secondary);
|
||||
// }
|
||||
.percentage-label {
|
||||
display: block;
|
||||
margin-top: 45px;
|
||||
margin-top: 50px;
|
||||
margin-left: 10px;
|
||||
font-size: 12px;
|
||||
color: var(--el-color-primary);
|
||||
font-size: 13px;
|
||||
color: var(--app-text-color-secondary);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -145,10 +145,7 @@ function createUser() {
|
|||
title.value = '创建用户'
|
||||
UserDialogRef.value.open()
|
||||
} else {
|
||||
MsgAlert(
|
||||
'提示',
|
||||
'社区版最多支持 2 个用户,如需拥有更多用户,请联系我们(https://fit2cloud.com/)。'
|
||||
)
|
||||
MsgAlert('提示', '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue