feat: 模型管理支持向量模型,知识库可以关联向量模型

feat:  模型管理支持向量模型,知识库可以关联向量模型
This commit is contained in:
wangdan-fit2cloud 2024-07-19 02:21:24 -07:00 committed by GitHub
commit d3d09b10ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
74 changed files with 1562 additions and 544 deletions

View File

@ -43,6 +43,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"),
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
@ -56,6 +57,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
@ -67,6 +69,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:param search_mode 检索模式
:param user_id 用户id
:return: 段落列表
"""
pass

View File

@ -13,22 +13,48 @@ from django.db.models import QuerySet
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from dataset.models import Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return dataset_list[0].embedding_mode_id
class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
@ -101,7 +127,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': EmbeddingModel.get_embedding_model().model_name,
'model_name': self.context.get('model_name'),
'message_tokens': 0,
'answer_tokens': 0,
'cost': 0

View File

@ -111,6 +111,7 @@ class FlowParamsSerializer(serializers.Serializer):
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案"))

View File

@ -13,20 +13,50 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import EmbeddingModel, VectorStore
from common.config.embedding_config import VectorStore, ModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph
from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id, user_id):
model = QuerySet(Model).filter(id=_id).first()
if model is None:
raise Exception("模型不存在")
if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id):
raise Exception(f"无权限使用此模型:{model.name}")
return model
def get_embedding_id(dataset_id_list):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return dataset_list[0].embedding_mode_id
def get_none_result(question):
return NodeResult(
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
'directly_return': ''}, {})
class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
self.context['question'] = question
embedding_model = EmbeddingModel.get_embedding_model()
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id, self.flow_params_serializer.data.get('user_id'))
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
@ -37,7 +67,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)

View File

@ -0,0 +1,19 @@
# Generated by Django 4.2.13 on 2024-07-15 15:52
import application.models.application
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0009_application_type_application_work_flow_and_more'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='details',
field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, verbose_name='对话详情'),
),
]

View File

@ -26,7 +26,7 @@ from rest_framework import serializers
from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
@ -36,7 +36,7 @@ from common.util.common import valid_license
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
# 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
@ -522,12 +523,14 @@ class ApplicationSerializer(serializers.Serializer):
if not QuerySet(Application).filter(id=self.data.get('application_id')).exists():
raise AppApiException(500, '不存在的应用id')
def list_model(self, with_valid=True):
def list_model(self, model_type=None, with_valid=True):
if with_valid:
self.is_valid()
if model_type is None:
model_type = "LLM"
application = QuerySet(Application).filter(id=self.data.get("application_id")).first()
return ModelSerializer.Query(
data={'user_id': application.user_id}).list(
data={'user_id': application.user_id, 'model_type': model_type}).list(
with_valid=True)
def delete(self, with_valid=True):

View File

@ -88,7 +88,9 @@ class ChatInfo:
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}'}
'value': '{question}',
},
'user_id': self.application.user_id
}
@ -221,11 +223,13 @@ class ChatMessageSerializer(serializers.Serializer):
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
user_id = chat_info.application.user_id
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat}, WorkFlowPostHandler(chat_info, client_id, client_type))
're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
r = work_flow_manage.run()
return r

View File

@ -29,6 +29,7 @@ from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
from common.config.embedding_config import ModelManage
from common.constants.permission_constants import RoleConstants
from common.db.search import native_search, native_page_search, page_search, get_dynamics_model
from common.event import ListenerManagement
@ -39,8 +40,10 @@ from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider import get_model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
@ -241,12 +244,7 @@ class ChatSerializers(serializers.Serializer):
application_id=application_id)]
chat_model = None
if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
chat_model = ModelManage.get_model(str(model.id), lambda _id: get_model(model))
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
@ -259,6 +257,7 @@ class ChatSerializers(serializers.Serializer):
class OpenWorkFlowChat(serializers.Serializer):
work_flow = WorkFlowSerializers(error_messages=ErrMessage.uuid("工作流"))
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
def open(self):
self.is_valid(raise_exception=True)
@ -269,7 +268,8 @@ class ChatSerializers(serializers.Serializer):
dataset_setting={},
model_setting={},
problem_optimization=None,
type=ApplicationTypeChoices.WORK_FLOW
type=ApplicationTypeChoices.WORK_FLOW,
user_id=self.data.get('user_id')
)
work_flow_version = WorkFlowVersion(work_flow=work_flow)
chat_cache.set(chat_id,
@ -332,7 +332,8 @@ class ChatSerializers(serializers.Serializer):
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'))
problem_optimization=self.data.get('problem_optimization'),
user_id=user_id)
chat_cache.set(chat_id,
ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
raise AppApiException(500, "文档id不正确")
@staticmethod
def post_embedding_paragraph(chat_record, paragraph_id):
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
return chat_record
@post(post_function=post_embedding_paragraph)
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record.improve_paragraph_id_list.append(paragraph.id)
# 添加标注
chat_record.save()
return ChatRecordSerializerModel(chat_record).data, paragraph.id
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

View File

@ -67,6 +67,20 @@ class ApplicationApi(ApiMixin):
}
)
class Model(ApiMixin):
@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='application_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='应用id'),
openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
type=openapi.TYPE_STRING,
required=False,
description='模型类型'),
]
class ApiKey(ApiMixin):
@staticmethod
def get_request_params_api():

View File

@ -175,7 +175,7 @@ class Application(APIView):
@swagger_auto_schema(operation_summary="获取模型列表",
operation_id="获取模型列表",
tags=["应用"],
manual_parameters=ApplicationApi.ApiKey.get_request_params_api())
manual_parameters=ApplicationApi.Model.get_request_params_api())
@has_permissions(ViewPermission(
[RoleConstants.ADMIN, RoleConstants.USER],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
@ -185,7 +185,7 @@ class Application(APIView):
return result.success(
ApplicationSerializer.Operate(
data={'application_id': application_id,
'user_id': request.user.id}).list_model())
'user_id': request.user.id}).list_model(request.query_params.get('model_type')))
class Profile(APIView):
authentication_classes = [TokenAuth]

View File

@ -41,3 +41,7 @@ class MemCache(LocMemCache):
delete_keys.append(key)
for key in delete_keys:
self._delete(key)
def clear_timeout_data(self):
for key in self._cache.keys():
self.get(key)

View File

@ -6,33 +6,36 @@
@date2023/10/23 16:03
@desc:
"""
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
import time
from smartdoc.const import CONFIG
from common.cache.mem_cache import MemCache
class EmbeddingModel:
instance = None
class ModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def get_embedding_model():
"""
获取向量化模型
:return:
"""
if EmbeddingModel.instance is None:
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
device = CONFIG.get('EMBEDDING_DEVICE')
encode_kwargs = {'normalize_embeddings': True}
e = HuggingFaceEmbeddings(
model_name=model_name,
cache_folder=cache_folder,
model_kwargs={'device': device},
encode_kwargs=encode_kwargs,
)
EmbeddingModel.instance = e
return EmbeddingModel.instance
def get_model(_id, get_model):
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
return model_instance
# 续期
ModelManage.cache.touch(_id, timeout=60 * 30)
ModelManage.clear_timeout_cache()
return model_instance
@staticmethod
def clear_timeout_cache():
if time.time() - ModelManage.up_clear_time > 60:
ModelManage.cache.clear_timeout_data()
@staticmethod
def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
class VectorStore:

View File

@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
def poxy(poxy_function):
def inner(args):
work_thread_pool.submit(poxy_function, args)
def inner(args, **keywords):
work_thread_pool.submit(poxy_function, args, **keywords)
return inner
def embedding_poxy(poxy_function):
def inner(args):
embedding_thread_pool.submit(poxy_function, args)
def inner(args, **keywords):
embedding_thread_pool.submit(poxy_function, args, **keywords)
return inner

View File

@ -15,8 +15,9 @@ from typing import List
import django.db.models
from blinker import signal
from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy, embedding_poxy
from common.util.file_util import get_file_content
@ -46,22 +47,26 @@ class SyncWebDocumentArgs:
class UpdateProblemArgs:
def __init__(self, problem_id: str, problem_content: str):
def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings):
self.problem_id = problem_id
self.problem_content = problem_content
self.embedding_model = embedding_model
class UpdateEmbeddingDatasetIdArgs:
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
def __init__(self, paragraph_id_list: List[str], target_dataset_id: str, target_embedding_model: Embeddings):
self.paragraph_id_list = paragraph_id_list
self.target_dataset_id = target_dataset_id
self.target_embedding_model = target_embedding_model
class UpdateEmbeddingDocumentIdArgs:
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str,
target_embedding_model: Embeddings = None):
self.paragraph_id_list = paragraph_id_list
self.target_document_id = target_document_id
self.target_dataset_id = target_dataset_id
self.target_embedding_model = target_embedding_model
class ListenerManagement:
@ -84,16 +89,46 @@ class ListenerManagement:
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
@staticmethod
def embedding_by_problem(args):
VectorStore.get_embedding_vector().save(**args)
def embedding_by_problem(args, embedding_model: Embeddings):
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
@staticmethod
def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings):
try:
data_list = native_search(
{'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter(
**{'paragraph.id__in': paragraph_id_list}),
'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql')))
ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list,
embedding_model=embedding_model)
except Exception as e:
max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id):
def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings):
max_kb.info(f'开始--->向量化段落:{paragraph_id_list}')
try:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
finally:
QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': status})
max_kb.info(f'结束--->向量化段落:{paragraph_id_list}')
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
"""
向量化段落 根据段落id
:param paragraph_id: 段落id
:return: None
@param paragraph_id: 段落id
@param embedding_model: 向量模型
"""
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
status = Status.success
@ -107,7 +142,7 @@ class ListenerManagement:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@ -117,12 +152,15 @@ class ListenerManagement:
@staticmethod
@embedding_poxy
def embedding_by_document(document_id):
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
向量化文档
:param document_id: 文档id
@param document_id: 文档id
@param embedding_model 向量模型
:return: None
"""
if not try_lock('embedding' + str(document_id)):
return
max_kb.info(f"开始--->向量化文档:{document_id}")
QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding})
@ -138,7 +176,7 @@ class ListenerManagement:
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@ -148,21 +186,24 @@ class ListenerManagement:
**{'status': status, 'update_time': datetime.datetime.now()})
QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status})
max_kb.info(f"结束--->向量化文档:{document_id}")
un_lock('embedding' + str(document_id))
@staticmethod
@embedding_poxy
def embedding_by_dataset(dataset_id):
def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
"""
向量化知识库
:param dataset_id: 知识库id
@param dataset_id: 知识库id
@param embedding_model 向量模型
:return: None
"""
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
try:
ListenerManagement.delete_embedding_by_dataset(dataset_id)
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list:
ListenerManagement.embedding_by_document(document.id)
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
except Exception as e:
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
finally:
@ -224,14 +265,22 @@ class ListenerManagement:
@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})
if args.target_embedding_model is None:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'dataset_id': args.target_dataset_id})
else:
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'document_id': args.target_document_id,
'dataset_id': args.target_dataset_id})
if args.target_embedding_model is None:
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'document_id': args.target_document_id,
'dataset_id': args.target_dataset_id})
else:
ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list,
embedding_model=args.target_embedding_model)
@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
@ -245,11 +294,6 @@ class ListenerManagement:
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
@staticmethod
@poxy
def init_embedding_model(ages):
EmbeddingModel.get_embedding_model()
def run(self):
# 添加向量 根据问题id
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
@ -276,8 +320,7 @@ class ListenerManagement:
ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
# 启动段落向量
ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
# 初始化向量化模型
ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
# 同步web站点知识库
ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
# 同步web站点 文档

View File

@ -0,0 +1,21 @@
# Generated by Django 4.2.13 on 2024-07-17 13:56
import dataset.models.data_set
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('setting', '0005_model_permission_type'),
('dataset', '0005_file'),
]
operations = [
migrations.AddField(
model_name='dataset',
name='embedding_mode',
field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'),
),
]

View File

@ -9,9 +9,11 @@
import uuid
from django.db import models
from django.db.models import QuerySet
from common.db.sql_execute import select_one
from common.mixins.app_model_mixin import AppModelMixin
from setting.models import Model
from users.models import User
@ -33,6 +35,10 @@ class HitHandlingMethod(models.TextChoices):
directly_return = 'directly_return', '直接返回'
def default_model():
return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab')
class DataSet(AppModelMixin):
"""
数据集表
@ -43,7 +49,8 @@ class DataSet(AppModelMixin):
user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户")
type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices,
default=Type.base)
embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型",
default=default_model)
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:

View File

@ -14,6 +14,7 @@ from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.config.embedding_config import ModelManage
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
@ -130,3 +132,22 @@ class ProblemParagraphManage:
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
is_create], problem_paragraph_mapping_list
return result
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
raise Exception("知识库未向量模型不一致")
if len(dataset_list) == 0:
raise Exception("知识库设置错误,请重新设置知识库")
return ModelManage.get_model(str(dataset_list[0].id),
lambda _id: get_model(dataset_list[0].embedding_mode))
def get_embedding_model_by_dataset_id(dataset_id: str):
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
return ModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
def get_embedding_model_by_dataset(dataset):
return ModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))

View File

@ -15,7 +15,6 @@ from functools import reduce
from typing import Dict, List
from urllib.parse import urlparse
from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
@ -25,7 +24,7 @@ from drf_yasg import openapi
from rest_framework import serializers
from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
@ -206,6 +206,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256,
min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
documents = DocumentInstanceSerializer(required=False, many=True)
def is_valid(self, *, raise_exception=False):
@ -226,6 +228,8 @@ class DataSetSerializers(serializers.ModelSerializer):
max_length=256,
min_length=1)
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
file_list = serializers.ListSerializer(required=True,
error_messages=ErrMessage.list("文件列表"),
child=serializers.FileField(required=True,
@ -296,6 +300,8 @@ class DataSetSerializers(serializers.ModelSerializer):
min_length=1)
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), )
embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型"))
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("选择器"))
@ -347,6 +353,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id",
description="向量模型id"),
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
description="web站点url"),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
@ -355,8 +363,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod
def post_embedding_dataset(document_list, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
return document_list
def save_qa(self, instance: Dict, with_valid=True):
@ -365,7 +374,8 @@ class DataSetSerializers(serializers.ModelSerializer):
self.CreateQASerializers(data=instance).is_valid()
file_list = instance.get('file_list')
document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list])
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list}
dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list,
'embedding_mode_id': instance.get('embedding_mode_id')}
return self.save(dataset_instance, with_valid=True)
@valid_license(model=DataSet, count=50,
@ -381,7 +391,8 @@ class DataSetSerializers(serializers.ModelSerializer):
if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists():
raise AppApiException(500, "知识库名称重复!")
dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id})
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'embedding_mode_id': instance.get('embedding_mode_id')})
document_model_list = []
paragraph_model_list = []
@ -452,7 +463,8 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'type': Type.web,
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'),
'embedding_mode_id': instance.get('embedding_mode_id')}})
dataset.save()
ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),
@ -500,6 +512,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型',
description='向量模型'),
'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据",
items=DocumentSerializers().Create.get_request_body_api()
)
@ -557,12 +571,13 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Document).filter(
dataset_id=self.data.get('id'),
is_active=False)]
model = get_embedding_model_by_dataset_id(self.data.get('id'))
# 向量库检索
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),
@ -730,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'))
model = get_embedding_model_by_dataset_id(self.data.get('id'))
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
def list_application(self, with_valid=True):
if with_valid:
@ -769,6 +785,7 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(ApplicationDatasetMapping).filter(
dataset_id=self.data.get('id'))]))}
@transaction.atomic
def edit(self, dataset: Dict, user_id: str):
"""
修改知识库
@ -782,6 +799,8 @@ class DataSetSerializers(serializers.ModelSerializer):
raise AppApiException(500, "知识库名称重复!")
_dataset = QuerySet(DataSet).get(id=self.data.get("id"))
DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
if 'embedding_mode_id' in dataset:
_dataset.embedding_mode_id = dataset.get('embedding_mode_id')
if "name" in dataset:
_dataset.name = dataset.get("name")
if 'desc' in dataset:

View File

@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -234,12 +235,17 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
# 修改向量信息
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_dataset_id))
model = None
if dataset.embedding_mode_id != target_dataset.embedding_mode_id:
model = get_embedding_model_by_dataset_id(target_dataset_id)
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id)
# 修改向量信息
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
pid_list,
target_dataset_id, model))
@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
@ -392,7 +398,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
problem_paragraph_mapping_list) > 0 else None
# 向量化
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(document_id)
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
else:
document.status = Status.error
document.save()
@ -405,6 +412,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
@staticmethod
def get_request_params_api():
@ -530,7 +538,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
ListenerManagement.embedding_by_document_signal.send(document_id)
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
@transaction.atomic
def delete(self):
@ -599,8 +608,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return True
@staticmethod
def post_embedding(result, document_id):
ListenerManagement.embedding_by_document_signal.send(document_id)
def post_embedding(result, document_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
return result
@staticmethod
@ -646,7 +656,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_id = str(document_model.id)
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id
with_valid=True), document_id, dataset_id
@staticmethod
def get_sync_handler(dataset_id):
@ -803,9 +813,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
@staticmethod
def post_embedding(document_list):
def post_embedding(document_list, dataset_id):
for document_dict in document_list:
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
return document_list
@post(post_function=post_embedding)
@ -846,7 +857,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return [],
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
with_search_one=False), dataset_id
@staticmethod
def _batch_sync(document_id_list: List[str]):

View File

@ -20,9 +20,9 @@ from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage
ProblemParagraphManage, get_embedding_model_by_dataset_id, get_embedding_model_by_dataset
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save()
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
}, embedding_model=model)
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'),
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
problem_id=problem.id)
problem_paragraph_mapping.save()
if with_embedding:
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
}, embedding_model=model)
def un_association(self, with_valid=True):
if with_valid:
@ -336,10 +338,11 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['document_id'])
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
target_document_id, target_dataset_id, target_embedding_model=None))
# 修改段落信息
paragraph_list.update(document_id=target_document_id)
# 不同数据集迁移
@ -366,12 +369,19 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改mapping
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['problem_id', 'dataset_id', 'document_id'])
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first()
dataset = QuerySet(DataSet).filter(id=dataset_id).first()
embedding_model = None
if target_dataset.embedding_mode_id != dataset.embedding_mode_id:
embedding_model = get_embedding_model_by_dataset(target_dataset)
pid_list = [paragraph.id for paragraph in paragraph_list]
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
# 修改向量段落信息
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
pid_list,
target_document_id, target_dataset_id, target_embedding_model=embedding_model))
update_document_char_length(document_id)
update_document_char_length(target_document_id)
@ -454,13 +464,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
raise AppApiException(500, "段落id不存在")
@staticmethod
def post_embedding(paragraph, instance):
def post_embedding(paragraph, instance, dataset_id):
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(paragraph.get('id'))
else:
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
return paragraph
@post(post_embedding)
@ -508,7 +519,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
_paragraph.save()
update_document_char_length(self.data.get('document_id'))
return self.one(), instance
return self.one(), instance, self.data.get('dataset_id')
def get_problem_list(self):
ProblemParagraphMapping(ProblemParagraphMapping)
@ -582,7 +593,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改长度
update_document_char_length(document_id)
if with_embedding:
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
return ParagraphSerializers.Operate(
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True)

View File

@ -20,7 +20,8 @@ from common.event import ListenerManagement, UpdateProblemArgs
from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import Problem, Paragraph, ProblemParagraphMapping
from dataset.models import Problem, Paragraph, ProblemParagraphMapping, DataSet
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from smartdoc.conf import PROJECT_DIR
@ -157,6 +158,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
content = instance.get('content')
problem = QuerySet(Problem).filter(id=problem_id,
dataset_id=dataset_id).first()
QuerySet(DataSet).filter(id=dataset_id)
problem.content = content
problem.save()
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.update_problem_signal.send(UpdateProblemArgs(problem_id, content, model))

View File

@ -52,7 +52,6 @@ class Dataset(APIView):
@action(methods=['POST'], detail=False)
@swagger_auto_schema(operation_summary="创建QA知识库",
operation_id="创建QA知识库",
manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(),
responses=get_api_response(
DataSetSerializers.Create.CreateQASerializers.get_response_body_api()),

View File

@ -10,9 +10,8 @@ import threading
from abc import ABC, abstractmethod
from typing import List, Dict
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.util.common import sub_array
from embedding.models import SourceType, SearchMode
@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding=None):
embedding: Embeddings):
"""
插入向量数据
:param source_id: 资源id
@ -64,13 +63,10 @@ class BaseVectorStore(ABC):
:param paragraph_id 段落id
:return: bool
"""
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
def batch_save(self, data_list: List[Dict], embedding=None):
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
# 获取锁
lock.acquire()
try:
@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
:param embedding: 向量化处理器
:return: bool
"""
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
result = sub_array(data_list)
for child_array in result:
@ -94,17 +88,17 @@ class BaseVectorStore(ABC):
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
pass
@abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
embedding_query = embedding.embed_query(query_text)
@ -123,7 +117,7 @@ class BaseVectorStore(ABC):
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
pass
@abstractmethod
@ -142,14 +136,6 @@ class BaseVectorStore(ABC):
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
pass
@abstractmethod
def embed_documents(self, text_list: List[str]):
pass
@abstractmethod
def embed_query(self, text: str):
pass
@abstractmethod
def delete_by_dataset_id(self, dataset_id: str):
pass

View File

@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
from typing import Dict, List
from django.db.models import QuerySet
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.util.file_util import get_file_content
@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
def embed_documents(self, text_list: List[str]):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_documents(text_list)
def embed_query(self, text: str):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_query(text)
def vector_is_create(self) -> bool:
# 项目启动默认是创建好的 不需要再创建
return True
@ -50,7 +41,7 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(),
dataset_id=dataset_id,
@ -64,7 +55,7 @@ class PGVector(BaseVectorStore):
embedding.save()
return True
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(),
@ -83,7 +74,7 @@ class PGVector(BaseVectorStore):
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
exclude_dict = {}

View File

@ -0,0 +1,46 @@
# Generated by Django 4.2.13 on 2024-07-15 15:23
import json
from django.db import migrations, models
from django.db.models import QuerySet
from common.util.rsa_util import rsa_long_encrypt
from setting.models import Status, PermissionType
from smartdoc.const import CONFIG
default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab'
def save_default_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
credential = {'cache_folder': cache_folder}
model_credential_str = json.dumps(credential)
model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS,
model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab',
provider='model_local_provider',
credential=rsa_long_encrypt(model_credential_str), meta={},
permission_type=PermissionType.PUBLIC)
model.save()
def reverse_code_embedding_model(apps, schema_editor):
ModelModel = apps.get_model('setting', 'Model')
QuerySet(ModelModel).filter(id=default_embedding_model_id).delete()
class Migration(migrations.Migration):
dependencies = [
('setting', '0004_alter_model_credential'),
]
operations = [
migrations.AddField(
model_name='model',
name='permission_type',
field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20,
verbose_name='权限类型'),
),
migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model)
]

View File

@ -22,6 +22,13 @@ class Status(models.TextChoices):
DOWNLOAD = "DOWNLOAD", '下载中'
PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载'
class PermissionType(models.TextChoices):
PUBLIC = "PUBLIC", '公开'
PRIVATE = "PRIVATE", "私有"
class Model(AppModelMixin):
"""
@ -46,6 +53,9 @@ class Model(AppModelMixin):
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
default=PermissionType.PRIVATE)
class Meta:
db_table = "model"
unique_together = ['name', 'user_id']

View File

@ -6,3 +6,85 @@
@date2023/10/31 17:16
@desc:
"""
import json
from typing import Dict
from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential):
"""
获取模型实例
@param provider: 供应商
@param model_type: 模型类型
@param model_name: 模型名称
@param credential: 认证信息
@return: 模型实例
"""
model = get_provider(provider).get_model(model_type, model_name,
json.loads(
rsa_long_decrypt(credential)),
streaming=True)
return model
def get_model(model):
"""
获取模型实例
@param model: model 数据库Model实例对象
@return: 模型实例
"""
return get_model_(model.provider, model.model_type, model.model_name, model.credential)
def get_provider(provider):
"""
获取供应商实例
@param provider: 供应商字符串
@return: 供应商实例
"""
return ModelProvideConstants[provider].value
def get_model_list(provider, model_type):
"""
获取模型列表
@param provider: 供应商字符串
@param model_type: 模型类型
@return: 模型列表
"""
return get_provider(provider).get_model_list(model_type)
def get_model_credential(provider, model_type, model_name):
"""
获取模型认证实例
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@return: 认证实例对象
"""
return get_provider(provider).get_model_credential(model_type, model_name)
def get_model_type_list(provider):
"""
获取模型类型列表
@param provider: 供应商字符串
@return: 模型类型列表
"""
return get_provider(provider).get_model_type_list()
def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
"""
校验模型认证参数
@param provider: 供应商字符串
@param model_type: 模型类型
@param model_name: 模型名称
@param model_credential: 模型认证数据
@param raise_exception: 是否抛出错误
@return: True|False
"""
return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)

View File

@ -61,7 +61,7 @@ class IModelProvider(ABC):
def get_model_list(self, model_type):
if model_type is None:
raise AppApiException(500, '模型类型不能为空')
return self.get_model_info_manage().get_model_list()
return self.get_model_info_manage().get_model_list_by_model_type(model_type)
def get_model_credential(self, model_type, model_name):
model_info = self.get_model_info_manage().get_model_info(model_type, model_name)
@ -191,6 +191,9 @@ class ModelInfoManage:
def get_model_list(self):
return [model.to_dict() for model in self.model_list]
def get_model_list_by_model_type(self, model_type):
return [model.to_dict() for model in self.model_list if model.model_type == model_type]
def get_model_type_list(self):
return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if
len([model for model in self.model_list if model.model_type == _type.name]) > 0]

View File

@ -18,6 +18,7 @@ from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
class ModelProvideConstants(Enum):
@ -31,3 +32,4 @@ class ModelProvideConstants(Enum):
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
model_gemini_provider = GeminiModelProvider()
model_local_provider = LocalModelProvider()

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/7/10 17:48
@desc:
"""

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 11:06
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
if not model_type == 'EMBEDDING':
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['cache_folder']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return model
cache_folder = forms.TextInputField('模型目录', required=True)

View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" t="1720668342208" class="icon" viewBox="0 0 1024 1024" version="1.1" p-id="9052" width="100%" height="100%"><path d="M512.2 475.7c-8.2-0.3-16.1-2.4-23.4-6.1L192.6 330.2c-24.9-11.1-36.1-40.3-25-65.2 5-11.2 13.9-20.1 25-25l281.5-133.2a89.43 89.43 0 0 1 76 0L831.7 240c24.9 11.1 36.1 40.3 25 65.2-5 11.2-13.9 20.1-25 25L535.5 469.5c-7.2 3.8-15.2 5.9-23.3 6.2z m-76.5 452.5c-7.6 0-15.1-1.9-21.8-5.5L146.3 797.2c-17-8.9-27.5-26.5-27.3-45.6v-320c0.1-18 9.7-34.5 25.1-43.7 14.3-8.1 31.8-8.1 46.1 0l267.1 125.4c16.1 8.7 26.4 25.4 27.1 43.7v320.4c-0.2 17.9-9.6 34.4-24.9 43.7-7.2 4.3-15.4 6.8-23.8 7.1z m152.9 0c-8.3 0-16.5-2.2-23.8-6.3-15.3-9.3-24.7-25.8-24.9-43.7V556.9c0.4-18.2 10.4-34.8 26.2-43.7L835 387c14.2-7.5 31.4-7.1 45.2 1.1 15.5 9.1 25 25.7 25.1 43.7v319.8c0.4 18.9-9.7 36.5-26.2 45.6L610.5 922.8c-6.8 3.6-14.3 5.5-21.9 5.4z" p-id="9053"/></svg>

After

Width:  |  Height:  |  Size: 931 B

View File

@ -0,0 +1,38 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file zhipu_model_provider.py
@date2024/04/19 13:5
@desc:
"""
import os
from typing import Dict
from pydantic import BaseModel
from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from smartdoc.conf import PROJECT_DIR
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.build())
class LocalModelProvider(IModelProvider):
def get_model_info_manage(self):
return model_info_manage
def get_model_provide_info(self):
return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content(
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon',
'local_icon_svg')))

View File

@ -0,0 +1,22 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 14:06
@desc:
"""
from typing import Dict
from langchain_huggingface import HuggingFaceEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'),
model_kwargs={'device': model_credential.get('device')},
encode_kwargs={'normalize_embeddings': True},
)

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:10
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'))
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
return True
def encryption_dict(self, model_info: Dict[str, object]):
return model_info
def build_model(self, model_info: Dict[str, object]):
for key in ['model']:
if key not in model_info:
raise AppApiException(500, f'{key} 字段为必填字段')
return self
api_base = forms.TextInputField('API 域名', required=True)

View File

@ -0,0 +1,48 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 15:02
@desc:
"""
from typing import Dict, List
from langchain_community.embeddings import OllamaEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OllamaEmbedding(MaxKBBaseModel, OllamaEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OllamaEmbedding(
model=model_name,
base_url=model_credential.get('api_base'),
)
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed documents using an Ollama deployed embedding model.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = [f"{text}" for text in texts]
embeddings = self._embed(instruction_pairs)
return embeddings
def embed_query(self, text: str) -> List[float]:
"""Embed a query using a Ollama deployed embedding model.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
instruction_pair = f"{text}"
embedding = self._embed([instruction_pair])[0]
return embedding

View File

@ -20,7 +20,9 @@ from common.forms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential, DownModelChunk, DownModelChunkStatus, ValidCode, ModelInfoManage
from setting.models_provider.impl.ollama_model_provider.credential.embedding import OllamaEmbeddingModelCredential
from setting.models_provider.impl.ollama_model_provider.credential.llm import OllamaLLMModelCredential
from setting.models_provider.impl.ollama_model_provider.model.embedding import OllamaEmbedding
from setting.models_provider.impl.ollama_model_provider.model.llm import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
@ -88,14 +90,25 @@ model_info_list = [
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel),
]
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
embedding_model_info = [
ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_model_info_list(
embedding_model_info).append_default_model_info(
ModelInfo(
'phi3',
'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。',
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).build()
ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)).append_default_model_info(ModelInfo(
'nomic-embed-text',
'一个具有大令牌上下文窗口的高性能开放嵌入模型。',
ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding), ).build()
def get_base_url(url: str):
@ -139,7 +152,6 @@ def convert(response_stream) -> Iterator[DownModelChunk]:
temp = ""
if len(temp) > 0:
print(temp)
rows = [t for t in temp.split("\n") if len(t) > 0]
for row in rows:
yield convert_to_down_model_chunk(row, index)
@ -154,9 +166,6 @@ class OllamaModelProvider(IModelProvider):
os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
'ollama_icon_svg')))
def get_dialogue_number(self):
return 2
@staticmethod
def get_base_model_list(api_base):
base_url = get_base_url(api_base)
@ -165,7 +174,7 @@ class OllamaModelProvider(IModelProvider):
return r.json()
def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]:
api_base = model_credential.get('api_base')
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
r = requests.request(
method="POST",

View File

@ -0,0 +1,46 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 16:45
@desc:
"""
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=True):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['api_base', 'api_key']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query('你好')
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

View File

@ -0,0 +1,23 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/12 17:44
@desc:
"""
from typing import Dict
from langchain_community.embeddings import OpenAIEmbeddings
from setting.models_provider.base_model_provider import MaxKBBaseModel
class OpenAIEmbeddingModel(MaxKBBaseModel, OpenAIEmbeddings):
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return OpenAIEmbeddingModel(
api_key=model_credential.get('api_key'),
model=model_name,
openai_api_base=model_credential.get('api_base'),
)

View File

@ -11,7 +11,9 @@ import os
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
ModelTypeConst, ModelInfoManage
from setting.models_provider.impl.openai_model_provider.credential.embedding import OpenAIEmbeddingCredential
from setting.models_provider.impl.openai_model_provider.credential.llm import OpenAILLMModelCredential
from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel
from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel
from smartdoc.conf import PROJECT_DIR
@ -58,11 +60,17 @@ model_info_list = [
ModelTypeConst.LLM, openai_llm_model_credential,
OpenAIChatModel)
]
open_ai_embedding_credential = OpenAIEmbeddingCredential()
model_info_embedding_list = [
ModelInfo('text-embedding-ada-002', '',
ModelTypeConst.EMBEDDING, open_ai_embedding_credential,
OpenAIEmbeddingModel)]
model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info(
ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential, OpenAIChatModel
)).build()
)).append_model_info_list(model_info_embedding_list).append_default_model_info(
model_info_embedding_list[0]).build()
class OpenAIModelProvider(IModelProvider):

View File

@ -7,12 +7,14 @@
@desc:
"""
import json
import re
import threading
import time
import uuid
from typing import Dict
from django.db.models import QuerySet
from django.core import validators
from django.db.models import QuerySet, Q
from rest_framework import serializers
from application.models import Application
@ -36,6 +38,9 @@ class ModelPullManage:
for chunk in response:
down_model_chunk[chunk.digest] = chunk.to_dict()
if time.time() - timestamp > 5:
model_new = QuerySet(Model).filter(id=model.id).first()
if model_new.status == Status.PAUSE_DOWNLOAD:
return
QuerySet(Model).filter(id=model.id).update(
meta={"down_model_chunk": list(down_model_chunk.values())})
timestamp = time.time()
@ -72,7 +77,7 @@ class ModelSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
name = self.data.get('name')
model_query_set = QuerySet(Model).filter(user_id=user_id)
model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC')))
query_params = {}
if name is not None:
query_params['name__contains'] = name
@ -85,7 +90,8 @@ class ModelSerializer(serializers.Serializer):
return [
{'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name, 'status': model.status, 'meta': model.meta} for model in
'model_name': model.model_name, 'status': model.status, 'meta': model.meta,
'permission_type': model.permission_type} for model in
model_query_set.filter(**query_params).order_by("-create_time")]
class Edit(serializers.Serializer):
@ -96,6 +102,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型"))
credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息"))
@ -135,6 +146,11 @@ class ModelSerializer(serializers.Serializer):
model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型"))
permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"),
message="权限只支持PUBLIC|PRIVATE", code=500)
])
model_name = serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型"))
credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息"))
@ -165,10 +181,12 @@ class ModelSerializer(serializers.Serializer):
provider = self.data.get('provider')
model_type = self.data.get('model_type')
model_name = self.data.get('model_name')
permission_type = self.data.get('permission_type')
model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
provider=provider, model_type=model_type, model_name=model_name,
permission_type=permission_type)
model.save()
if status == Status.DOWNLOAD:
thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential))
@ -184,7 +202,8 @@ class ModelSerializer(serializers.Serializer):
'meta': model.meta,
'credential': ModelProvideConstants[model.provider].value.get_model_credential(model.model_type,
model.model_name).encryption_dict(
credential)}
credential),
'permission_type': model.permission_type}
class Operate(serializers.Serializer):
id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
@ -210,7 +229,8 @@ class ModelSerializer(serializers.Serializer):
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
'meta': model.meta, }
'meta': model.meta
}
def delete(self, with_valid=True):
if with_valid:
@ -221,6 +241,12 @@ class ModelSerializer(serializers.Serializer):
QuerySet(Model).filter(id=self.data.get('id')).delete()
return True
def pause_download(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD)
return True
def edit(self, instance: Dict, user_id: str, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
@ -245,7 +271,7 @@ class ModelSerializer(serializers.Serializer):
model.status = Status.DOWNLOAD
else:
raise e
update_keys = ['credential', 'name', 'model_type', 'model_name']
update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':

View File

@ -74,6 +74,8 @@ class ModelCreateApi(ApiMixin):
'provider': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限",
description="PUBLIC|PRIVATE"),
'model_type': openapi.Schema(type=openapi.TYPE_STRING,
title="供应商",
description="供应商"),
@ -82,7 +84,8 @@ class ModelCreateApi(ApiMixin):
description="供应商"),
'credential': openapi.Schema(type=openapi.TYPE_OBJECT,
title="模型证书信息",
description="模型证书信息")
description="模型证书信息"),
}
)

View File

@ -16,6 +16,7 @@ urlpatterns = [
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'),
path('valid/<str:valid_type>/<int:valid_count>', views.Valid.as_view())

View File

@ -69,6 +69,18 @@ class Model(APIView):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True))
class PauseDownload(APIView):
authentication_classes = [TokenAuth]
@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="暂停模型下载",
operation_id="暂停模型下载",
tags=["模型"])
@has_permissions(PermissionConstants.MODEL_CREATE)
def put(self, request: Request, model_id: str):
return result.success(
ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download())
class Operate(APIView):
authentication_classes = [TokenAuth]

View File

@ -104,9 +104,6 @@ CACHES = {
"token_cache": {
'BACKEND': 'common.cache.file_cache.FileCache',
'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径
},
"chat_cache": {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
}
}

View File

@ -129,11 +129,12 @@ const getDatasetDetail: (dataset_id: string, loading?: Ref<boolean>) => Promise<
"desc": true
}
*/
const putDataset: (dataset_id: string, data: any) => Promise<Result<any>> = (
dataset_id,
data: any
) => {
return put(`${prefix}/${dataset_id}`, data)
const putDataset: (
dataset_id: string,
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
return put(`${prefix}/${dataset_id}`, data, undefined, loading)
}
/**
*

View File

@ -130,7 +130,18 @@ const getModelMetaById: (model_id: string, loading?: Ref<boolean>) => Promise<Re
) => {
return get(`${prefix}/${model_id}/meta`, {}, loading)
}
/**
*
* @param model_id id
* @param loading
* @returns
*/
const pauseDownload: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id,
loading
) => {
return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading)
}
const deleteModel: (model_id: string, loading?: Ref<boolean>) => Promise<Result<boolean>> = (
model_id,
loading
@ -147,5 +158,6 @@ export default {
updateModel,
deleteModel,
getModelById,
getModelMetaById
getModelMetaById,
pauseDownload
}

View File

@ -3,6 +3,7 @@ interface datasetData {
desc: String
documents?: Array<any>
type?: String
embedding_mode_id?: String
}
export type { datasetData }

View File

@ -53,6 +53,7 @@ interface Model {
*
*/
model_type: string
permission_type: 'PUBLIC' | 'PRIVATE'
/**
*
*/
@ -68,7 +69,7 @@ interface Model {
/**
*
*/
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR'
status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD'
/**
*
*/

View File

@ -18,7 +18,8 @@
</slot>
<slot></slot>
</div>
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)"> </el-checkbox>
<el-checkbox v-bind:modelValue="modelValue.includes(toModelValue)" @change="checkboxChange">
</el-checkbox>
</div>
</el-card>
</template>
@ -40,7 +41,7 @@ const toModelValue = computed(() => (props.valueField ? props.data[props.valueFi
// set: (val) => val
// })
const emit = defineEmits(['update:modelValue'])
const emit = defineEmits(['update:modelValue', 'change'])
const checked = () => {
const value = props.modelValue ? props.modelValue : []
@ -53,6 +54,10 @@ const checked = () => {
emit('update:modelValue', [...value, toModelValue.value])
}
}
function checkboxChange() {
emit('change')
}
</script>
<style lang="scss" scoped>
.card-checkbox {

View File

@ -0,0 +1,93 @@
<template>
<div class="loading-container loader">
<div class="download-loading">
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
<div></div>
</div>
</div>
</template>
<script setup lang="ts"></script>
<style lang="scss" scoped>
.loading-container {
display: -webkit-flex; /*safari弹性布局*/
justify-content: center;
display: flex;
align-items: center;
}
@-webkit-keyframes loader {
0% {
opacity: 0.3;
}
80% {
opacity: 1;
}
100% {
opacity: 1;
}
}
.download-loading {
position: relative;
}
.download-loading div {
width: 5px;
height: 12px;
background: var(--el-color-info);
position: absolute;
border-radius: 2px;
margin: 0 auto;
}
.download-loading div:nth-child(1) {
top: -20px;
left: 0;
-webkit-animation: loader 1s -0.8s infinite ease-in-out;
}
.download-loading div:nth-child(2) {
top: -13px;
left: 13px;
-webkit-transform: rotate(45deg);
-webkit-animation: loader 1s -0.6s infinite ease-in-out;
}
.download-loading div:nth-child(3) {
top: 0px;
left: 20px;
-webkit-transform: rotate(90deg);
-webkit-animation: loader 1s -0.5s infinite ease-in-out;
}
.download-loading div:nth-child(4) {
top: 13px;
left: 13px;
-webkit-transform: rotate(-45deg);
-webkit-animation: loader 1s -0.4s infinite ease-in-out;
}
.download-loading div:nth-child(5) {
top: 20px;
left: 0px;
-webkit-transform: rotate(0deg);
-webkit-animation: loader 1s -0.3s infinite ease-in-out;
}
.download-loading div:nth-child(6) {
top: 13px;
left: -13px;
-webkit-transform: rotate(45deg);
-webkit-animation: loader 1s -0.2s infinite ease-in-out;
}
.download-loading div:nth-child(7) {
top: 0px;
left: -20px;
-webkit-transform: rotate(90deg);
-webkit-animation: loader 1s -0.1s infinite ease-in-out;
}
.download-loading div:nth-child(8) {
top: -13px;
left: -13px;
-webkit-transform: rotate(-45deg);
-webkit-animation: loader 1s 0s infinite ease-in-out;
}
</style>

8
ui/src/enums/model.ts Normal file
View File

@ -0,0 +1,8 @@
export enum PermissionType {
PRIVATE = '私有',
PUBLIC = '公用'
}
export enum PermissionDesc {
PRIVATE = '仅自己使用',
PUBLIC = '所有用户都可使用,不能编辑'
}

View File

@ -99,7 +99,7 @@
</div>
</template>
<template v-else-if="isDataset">
<div class="w-full text-left cursor" @click="router.push({ path: '/dataset/create' })">
<div class="w-full text-left cursor" @click="openCreateDialog">
<el-button link>
<el-icon class="mr-4"><Plus /></el-icon>
</el-button>
@ -110,12 +110,14 @@
</el-dropdown>
</div>
<CreateApplicationDialog ref="CreateApplicationDialogRef" @refresh="refresh" />
<CreateDatasetDialog ref="CreateDatasetDialogRef" @refresh="refresh" />
</template>
<script setup lang="ts">
import { ref, onMounted, computed } from 'vue'
import { onBeforeRouteLeave, useRouter, useRoute } from 'vue-router'
import CreateApplicationDialog from '@/views/application/component/CreateApplicationDialog.vue'
import CreateDatasetDialog from '@/views/dataset/component/CreateDatasetDialog.vue'
import { isAppIcon, isWorkFlow } from '@/utils/application'
import useStore from '@/stores'
const { common, dataset, application } = useStore()
@ -130,6 +132,7 @@ onBeforeRouteLeave((to, from) => {
common.saveBreadcrumb(null)
})
const CreateDatasetDialogRef = ref()
const CreateApplicationDialogRef = ref()
const list = ref<any[]>([])
const loading = ref(false)
@ -148,7 +151,11 @@ const isDataset = computed(() => {
})
function openCreateDialog() {
CreateApplicationDialogRef.value.open()
if (isDataset.value) {
CreateDatasetDialogRef.value.open()
} else if (isApplication.value) {
CreateApplicationDialogRef.value.open()
}
}
function changeMenu(id: string) {

View File

@ -13,9 +13,9 @@ const datasetRouter = {
},
{
path: '/dataset/:type', // create 或者 upload
name: 'CreateDataset',
name: 'UploadDocumentDataset',
meta: { activeMenu: '/dataset' },
component: () => import('@/views/dataset/CreateDataset.vue'),
component: () => import('@/views/dataset/UploadDocumentDataset.vue'),
hidden: true
},
{

View File

@ -1,11 +1,11 @@
import { defineStore } from 'pinia'
import modelApi from '@/api/model'
import type { modelRequest, Provider } from '@/api/type/model'
import type { ListModelRequest, Provider } from '@/api/type/model'
const useModelStore = defineStore({
id: 'model',
state: () => ({}),
actions: {
async asyncGetModel(data?: modelRequest) {
async asyncGetModel(data?: ListModelRequest) {
return new Promise((resolve, reject) => {
modelApi
.getModel(data)

View File

@ -377,6 +377,11 @@ h5 {
color: var(--el-color-primary);
border: none;
}
.danger-tag {
background: var(--tag-danger-bg);
color: #d03f3b;
border: none;
}
.success-tag {
background: var(--tag-success-bg);
color: var(--el-color-success);
@ -388,6 +393,12 @@ h5 {
border: none;
}
.info-tag {
background: var(--app-text-color-light-1);
color: var(--app-text-color-secondary);
border: none;
}
.purple-tag {
background: #f2ebfe;
color: #7f3bf5;

View File

@ -30,6 +30,7 @@
--tag-success-color: #2ca91f;
--tag-warning-bg: rgba(255, 136, 0, 0.2);
--tag-warning-color: #d97400;
--tag-danger-bg: rgba(245, 74, 69, 0.2);
/** card */
--card-width: 330px;

View File

@ -4,21 +4,28 @@
v-model="dialogVisible"
width="600"
append-to-body
class="addDataset-dialog"
>
<template #header="{ titleId, titleClass }">
<div class="my-header flex">
<div class="flex-between mb-8">
<h4 :id="titleId" :class="titleClass">
{{ $t('views.application.applicationForm.dialogues.addDataset') }}
</h4>
<el-button link class="ml-16" @click="refresh">
<el-icon class="mr-4"><Refresh /></el-icon
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
</el-button>
<div class="flex align-center">
<el-button link class="ml-16" @click="refresh">
<el-icon class="mr-4"><Refresh /></el-icon
>{{ $t('views.application.applicationForm.dialogues.refresh') }}
</el-button>
<el-divider direction="vertical" />
</div>
</div>
<el-text type="info" class="color-secondary">
所选知识库必须使用相同的 Embedding 模型
</el-text>
</template>
<el-row :gutter="12" v-loading="loading">
<el-col :span="12" v-for="(item, index) in data" :key="index" class="mb-16">
<CardCheckbox value-field="id" :data="item" v-model="checkList">
<el-col :span="12" v-for="(item, index) in filterData" :key="index" class="mb-16">
<CardCheckbox value-field="id" :data="item" v-model="checkList" @change="changeHandle">
<span class="ellipsis">
{{ item.name }}
</span>
@ -26,19 +33,29 @@
</el-col>
</el-row>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submitHandle">
{{ $t('views.application.applicationForm.buttons.confirm') }}
</el-button>
</span>
<div class="flex-between">
<div>
<el-text type="info" class="color-secondary" v-if="checkList.length > 0">
已选 {{ checkList.length }} 个知识库
</el-text>
<el-button link type="primary" v-if="checkList.length > 0" @click="clearCheck">
清空
</el-button>
</div>
<span>
<el-button @click.prevent="dialogVisible = false">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submitHandle">
{{ $t('views.application.applicationForm.buttons.confirm') }}
</el-button>
</span>
</div>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { computed, ref, watch } from 'vue'
const props = defineProps({
data: {
type: Array<any>,
@ -51,6 +68,13 @@ const emit = defineEmits(['addData', 'refresh'])
const dialogVisible = ref<boolean>(false)
const checkList = ref([])
const currentEmbedding = ref('')
const filterData = computed(() => {
return currentEmbedding.value
? props.data.filter((v) => v.embedding_mode_id === currentEmbedding.value)
: props.data
})
watch(dialogVisible, (bool) => {
if (!bool) {
@ -58,6 +82,18 @@ watch(dialogVisible, (bool) => {
}
})
function changeHandle() {
if (checkList.value.length === 1) {
currentEmbedding.value = props.data.filter(
(v) => v.id === checkList.value[0]
)[0].embedding_mode_id
}
}
function clearCheck() {
checkList.value = []
currentEmbedding.value = ''
}
const open = (checked: any) => {
checkList.value = checked
dialogVisible.value = true
@ -73,4 +109,13 @@ const refresh = () => {
defineExpose({ open })
</script>
<style lang="scss" scope></style>
<style lang="scss" scope>
.addDataset-dialog {
.el-dialog__header.show-close {
padding-right: 15px;
}
.el-dialog__headerbtn {
top: 13px;
}
}
</style>

View File

@ -63,10 +63,10 @@
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false">
<el-button @click.prevent="dialogVisible = false" :loading="loading">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submitValid(applicationFormRef)">
<el-button type="primary" @click="submitValid(applicationFormRef)" :loading="loading">
{{ $t('views.application.applicationForm.buttons.create') }}
</el-button>
</span>
@ -183,10 +183,7 @@ const submitValid = (formEl: FormInstance | undefined) => {
if (res?.data) {
submitHandle(formEl)
} else {
MsgAlert(
'提示',
'社区版最多支持 5 个应用如需拥有更多应用请联系我们https://fit2cloud.com/)。'
)
MsgAlert('提示', '社区版最多支持 5 个应用,如需拥有更多应用,请升级为专业版。')
}
})
}

View File

@ -1,5 +1,5 @@
<template>
<div class="authentication-setting p-24">
<div class="authentication-setting p-16-24">
<h4>{{ $t('login.authentication') }}</h4>
<el-tabs v-model="activeName" class="demo-tabs" @tab-click="handleClick">
<template v-for="(item, index) in tabList" :key="index">
@ -38,7 +38,7 @@ onMounted(() => {})
background-color: var(--app-view-bg-color);
box-sizing: border-box;
min-width: 700px;
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 80px);
height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 70px);
box-sizing: border-box;
.form-container {
width: 70%;

View File

@ -3,6 +3,7 @@
<div class="dataset-setting main-calc-height">
<el-scrollbar>
<div class="p-24" v-loading="loading">
<h4 class="title-decoration-1 mb-16">基本信息</h4>
<BaseForm ref="BaseFormRef" :data="detail" />
<el-form
@ -104,7 +105,7 @@ import { useRoute } from 'vue-router'
import BaseForm from '@/views/dataset/component/BaseForm.vue'
import datasetApi from '@/api/dataset'
import type { ApplicationFormType } from '@/api/type/application'
import { MsgSuccess } from '@/utils/message'
import { MsgSuccess, MsgConfirm } from '@/utils/message'
import { isAppIcon } from '@/utils/application'
import useStore from '@/stores'
const route = useRoute()
@ -119,6 +120,8 @@ const loading = ref(false)
const detail = ref<any>({})
const application_list = ref<Array<ApplicationFormType>>([])
const application_id_list = ref([])
const cloneModelId = ref('')
const form = ref<any>({
source_url: '',
selector: ''
@ -132,7 +135,6 @@ async function submit() {
if (await BaseFormRef.value?.validate()) {
await webFormRef.value.validate((valid: any) => {
if (valid) {
loading.value = true
const obj =
detail.value.type === '1'
? {
@ -144,15 +146,25 @@ async function submit() {
application_id_list: application_id_list.value,
...BaseFormRef.value.form
}
datasetApi
.putDataset(id, obj)
.then((res) => {
if (cloneModelId.value !== BaseFormRef.value.form.embedding_mode_id) {
MsgConfirm(`提示`, `修改知识库向量模型后,需要对知识库重新向量化,是否继续保存?`, {
confirmButtonText: '重新向量化',
confirmButtonClass: 'primary'
})
.then(() => {
datasetApi.putDataset(id, obj, loading).then((res) => {
datasetApi.putReEmbeddingDataset(id).then(() => {
MsgSuccess('保存成功')
})
})
})
.catch(() => {})
} else {
datasetApi.putDataset(id, obj, loading).then((res) => {
MsgSuccess('保存成功')
loading.value = false
})
.catch(() => {
loading.value = false
})
}
}
})
}
@ -161,10 +173,10 @@ async function submit() {
function getDetail() {
dataset.asyncGetDatasetDetail(id, loading).then((res: any) => {
detail.value = res.data
cloneModelId.value = res.data?.embedding_mode_id
if (detail.value.type === '1') {
form.value = res.data.meta
}
application_id_list.value = res.data?.application_id_list
datasetApi.listUsableApplication(id, loading).then((ok) => {
application_list.value = ok.data

View File

@ -1,15 +1,18 @@
<template>
<LayoutContainer :header="isCreate ? '创建知识库' : '上传文档'" class="create-dataset">
<LayoutContainer header="上传文档" class="create-dataset">
<template #backButton>
<back-button @click="back"></back-button>
</template>
<div class="create-dataset__main flex" v-loading="loading">
<div class="create-dataset__component main-calc-height">
<template v-if="active === 0">
<StepFirst ref="StepFirstRef" />
<div class="upload-document p-24">
<!-- 上传文档 -->
<UploadComponent ref="UploadComponentRef" />
</div>
</template>
<template v-else-if="active === 1">
<StepSecond ref="StepSecondRef" />
<SetRules ref="SetRulesRef" />
</template>
<template v-else-if="active === 2">
<ResultSuccess :data="successInfo" />
@ -19,12 +22,7 @@
<div class="create-dataset__footer text-right border-t" v-if="active !== 2">
<el-button @click="router.go(-1)" :disabled="loading">取消</el-button>
<el-button @click="prev" v-if="active === 1" :disabled="loading">上一步</el-button>
<el-button
@click="next"
type="primary"
v-if="active === 0"
:disabled="loading || StepFirstRef?.loading"
>
<el-button @click="next" type="primary" v-if="active === 0" :disabled="loading">
创建并导入
</el-button>
<el-button @click="submit" type="primary" v-if="active === 1" :disabled="loading">
@ -36,9 +34,9 @@
<script setup lang="ts">
import { ref, computed, onUnmounted } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import StepFirst from './step/StepFirst.vue'
import StepSecond from './step/StepSecond.vue'
import ResultSuccess from './step/ResultSuccess.vue'
import SetRules from './component/SetRules.vue'
import ResultSuccess from './component/ResultSuccess.vue'
import UploadComponent from './component/UploadComponent.vue'
import datasetApi from '@/api/dataset'
import documentApi from '@/api/document'
import type { datasetData } from '@/api/type/dataset'
@ -46,33 +44,17 @@ import { MsgConfirm, MsgSuccess } from '@/utils/message'
import useStore from '@/stores'
const { dataset, document } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const webInfo = computed(() => dataset.webInfo)
const documentsFiles = computed(() => dataset.documentsFiles)
const documentsType = computed(() => dataset.documentsType)
const router = useRouter()
const route = useRoute()
const {
params: { type },
query: { id } // iddatasetIDid
} = route
const isCreate = type === 'create'
// const steps = [
// {
// ref: 'StepFirstRef',
// name: '',
// component: StepFirst
// },
// {
// ref: 'StepSecondRef',
// name: '',
// component: StepSecond
// }
// ]
const StepFirstRef = ref()
const StepSecondRef = ref()
const SetRulesRef = ref()
const UploadComponentRef = ref()
const loading = ref(false)
const disabled = ref(false)
@ -81,7 +63,7 @@ const successInfo = ref<any>(null)
async function next() {
disabled.value = true
if (await StepFirstRef.value?.onSubmit()) {
if (await UploadComponentRef.value.validate()) {
if (documentsType.value === 'QA') {
let fd = new FormData()
documentsFiles.value.forEach((item: any) => {
@ -118,16 +100,14 @@ const prev = () => {
}
function clearStore() {
dataset.saveBaseInfo(null)
dataset.saveWebInfo(null)
dataset.saveDocumentsFile([])
dataset.saveDocumentsType('')
}
function submit() {
loading.value = true
const documents = [] as any
StepSecondRef.value?.paragraphList.map((item: any) => {
if (!StepSecondRef.value?.checkedConnect) {
SetRulesRef.value?.paragraphList.map((item: any) => {
if (!SetRulesRef.value?.checkedConnect) {
item.content.map((v: any) => {
delete v['problem_list']
})
@ -159,7 +139,7 @@ function submit() {
}
}
function back() {
if (baseInfo.value || webInfo.value || documentsFiles.value?.length > 0) {
if (documentsFiles.value?.length > 0) {
MsgConfirm(`提示`, `当前的更改尚未保存,确认退出吗?`, {
confirmButtonText: '确认',
type: 'warning'
@ -206,5 +186,10 @@ onUnmounted(() => {
width: 100%;
box-sizing: border-box;
}
.upload-document {
width: 70%;
margin: 0 auto;
margin-bottom: 20px;
}
}
</style>

View File

@ -1,11 +1,11 @@
<template>
<h4 class="title-decoration-1 mb-16">基本信息</h4>
<el-form
ref="FormRef"
:model="form"
:rules="rules"
label-position="top"
require-asterisk-position="right"
v-loading="loading"
>
<el-form-item label="知识库名称" prop="name">
<el-input
@ -27,14 +27,72 @@
@blur="form.desc = form.desc.trim()"
/>
</el-form-item>
<el-form-item label="Embedding模型" prop="embedding_mode_id">
<el-select
v-model="form.embedding_mode_id"
placeholder="请选择Embedding模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
:key="value"
:label="relatedObject(providerOptions, label, 'provider')?.name"
>
<el-option
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
</div>
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
><Check
/></el-icon>
</el-option>
<!-- 不可用 -->
<el-option
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
disabled
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<span class="danger">{{
$t('views.application.applicationForm.form.aiModel.unavailable')
}}</span>
</div>
<el-icon class="check-icon" v-if="item.id === form.embedding_mode_id"
><Check
/></el-icon>
</el-option>
</el-option-group>
</el-select>
</el-form-item>
</el-form>
</template>
<script setup lang="ts">
import { ref, reactive, onMounted, onUnmounted, computed, watch } from 'vue'
import { useRoute } from 'vue-router'
import { groupBy } from 'lodash'
import useStore from '@/stores'
import type { datasetData } from '@/api/type/dataset'
import { isAllPropertiesEmpty } from '@/utils/utils'
import { relatedObject } from '@/utils/utils'
import type { Provider } from '@/api/type/model'
const props = defineProps({
data: {
@ -42,23 +100,23 @@ const props = defineProps({
default: () => {}
}
})
const route = useRoute()
const {
params: { type }
} = route
const isCreate = type === 'create'
const { dataset } = useStore()
const baseInfo = computed(() => dataset.baseInfo)
const { model } = useStore()
const form = ref<datasetData>({
name: '',
desc: ''
desc: '',
embedding_mode_id: ''
})
const rules = reactive({
name: [{ required: true, message: '请输入知识库名称', trigger: 'blur' }],
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }]
desc: [{ required: true, message: '请输入知识库描述', trigger: 'blur' }],
embedding_mode_id: [{ required: true, message: '请输入Embedding模型', trigger: 'change' }]
})
const FormRef = ref()
const loading = ref(false)
const modelOptions = ref<any>([])
const providerOptions = ref<Array<Provider>>([])
watch(
() => props.data,
@ -66,23 +124,13 @@ watch(
if (value && JSON.stringify(value) !== '{}') {
form.value.name = value.name
form.value.desc = value.desc
form.value.embedding_mode_id = value.embedding_mode_id
}
},
{
immediate: true
}
)
watch(form.value, (value) => {
if (isAllPropertiesEmpty(value)) {
dataset.saveBaseInfo(null)
} else {
if (isCreate) {
dataset.saveBaseInfo(value)
}
}
})
/*
表单校验
*/
@ -93,16 +141,43 @@ function validate() {
})
}
function getModel() {
loading.value = true
model
.asyncGetModel({ model_type: 'EMBEDDING' })
.then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
loading.value = false
})
.catch(() => {
loading.value = false
})
}
function getProvider() {
loading.value = true
model
.asyncGetProvider()
.then((res: any) => {
providerOptions.value = res?.data
loading.value = false
})
.catch(() => {
loading.value = false
})
}
onMounted(() => {
if (baseInfo.value) {
form.value = baseInfo.value
}
getProvider()
getModel()
})
onUnmounted(() => {
form.value = {
name: '',
desc: ''
desc: '',
embedding_mode_id: ''
}
FormRef.value?.clearValidate()
})
defineExpose({

View File

@ -0,0 +1,177 @@
<template>
<el-dialog title="创建知识库" v-model="dialogVisible" width="650" append-to-body>
<!-- 基本信息 -->
<BaseForm ref="BaseFormRef" v-if="dialogVisible" />
<el-form
ref="DatasetFormRef"
:rules="rules"
:model="datasetForm"
label-position="top"
require-asterisk-position="right"
>
<el-form-item label="知识库类型" required>
<el-radio-group v-model="datasetForm.type" class="card__radio" @change="radioChange">
<el-row :gutter="20">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="datasetForm.type === '0' ? 'active' : ''"
>
<el-radio value="0" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">通用型</p>
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="datasetForm.type === '1' ? 'active' : ''"
>
<el-radio value="1" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">Web 站点</p>
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item label="Web 根地址" prop="source_url" v-if="datasetForm.type === '1'">
<el-input
v-model="datasetForm.source_url"
placeholder="请输入 Web 根地址"
@blur="datasetForm.source_url = datasetForm.source_url.trim()"
/>
</el-form-item>
<el-form-item label="选择器" v-if="datasetForm.type === '1'">
<el-input
v-model="datasetForm.selector"
placeholder="默认为 body可输入 .classname/#idname/tagname"
@blur="datasetForm.selector = datasetForm.selector.trim()"
/>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false" :loading="loading">
{{ $t('views.application.applicationForm.buttons.cancel') }}
</el-button>
<el-button type="primary" @click="submitValid" :loading="loading">
{{ $t('views.application.applicationForm.buttons.create') }}
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import BaseForm from './BaseForm.vue'
import datasetApi from '@/api/dataset'
import { MsgSuccess, MsgAlert } from '@/utils/message'
import useStore from '@/stores'
import { ValidType, ValidCount } from '@/enums/common'
const emit = defineEmits(['refresh'])
const { common, user } = useStore()
const router = useRouter()
const BaseFormRef = ref()
const DatasetFormRef = ref()
const loading = ref(false)
const dialogVisible = ref<boolean>(false)
const datasetForm = ref<any>({
type: '0',
source_url: '',
selector: ''
})
const rules = reactive({
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
})
watch(dialogVisible, (bool) => {
if (!bool) {
datasetForm.value = {
type: '0',
source_url: '',
selector: ''
}
DatasetFormRef.value?.clearValidate()
}
})
const open = () => {
dialogVisible.value = true
}
const submitValid = () => {
if (user.isEnterprise()) {
submitHandle()
} else {
common.asyncGetValid(ValidType.Dataset, ValidCount.Dataset, loading).then(async (res: any) => {
if (res?.data) {
submitHandle()
} else {
MsgAlert('提示', '社区版最多支持 50 个知识库,如需拥有更多知识库,请升级为专业版。')
}
})
}
}
const submitHandle = async () => {
if (await BaseFormRef.value?.validate()) {
await DatasetFormRef.value.validate((valid: any) => {
if (valid) {
if (datasetForm.value.type === '0') {
const obj = {
...BaseFormRef.value.form,
type: datasetForm.value.type
}
datasetApi.postDataset(obj, loading).then((res) => {
MsgSuccess('创建成功')
router.push({ path: `/dataset/${res.data.id}/document` })
emit('refresh')
})
} else {
const obj = { ...BaseFormRef.value.form, ...datasetForm.value }
datasetApi.postWebDataset(obj, loading).then((res) => {
MsgSuccess('创建成功')
router.push({ path: `/dataset/${res.data.id}/document` })
emit('refresh')
})
}
} else {
return false
}
})
} else {
return false
}
}
function radioChange() {
datasetForm.value.source_url = ''
datasetForm.value.selector = ''
}
defineExpose({ open })
</script>
<style lang="scss" scope></style>

View File

@ -65,7 +65,7 @@
show-input
:show-input-controls="false"
:min="50"
:max="4096"
:max="100000"
/>
</div>
<div class="form-item mb-16">

View File

@ -22,7 +22,7 @@
>
<el-row :gutter="15">
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
<CardAdd title="创建知识库" @click="router.push({ path: '/dataset/create' })" />
<CardAdd title="创建知识库" @click="openCreateDialog" />
</el-col>
<template v-for="(item, index) in datasetList" :key="index">
<el-col :xs="24" :sm="12" :md="8" :lg="6" :xl="4" class="mb-16">
@ -107,17 +107,20 @@
</InfiniteScroll>
</div>
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
<CreateDatasetDialog ref="CreateDatasetDialogRef"/>
</div>
</template>
<script setup lang="ts">
import { ref, onMounted, reactive, computed } from 'vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import CreateDatasetDialog from './component/CreateDatasetDialog.vue'
import datasetApi from '@/api/dataset'
import { MsgSuccess, MsgConfirm } from '@/utils/message'
import { useRouter } from 'vue-router'
import { numberFormat } from '@/utils/utils'
const router = useRouter()
const CreateDatasetDialogRef = ref()
const SyncWebDialogRef = ref()
const loading = ref(false)
const datasetList = ref<any[]>([])
@ -129,6 +132,10 @@ const paginationConfig = reactive({
const searchValue = ref('')
function openCreateDialog() {
CreateDatasetDialogRef.value.open()
}
function refresh() {
MsgSuccess('同步任务发送成功')
}

View File

@ -1,183 +0,0 @@
<template>
<el-scrollbar>
<div class="upload-document p-24">
<!-- 基本信息 -->
<BaseForm ref="BaseFormRef" v-if="isCreate" />
<el-form
v-if="isCreate"
ref="webFormRef"
:rules="rules"
:model="form"
label-position="top"
require-asterisk-position="right"
>
<el-form-item label="知识库类型" required>
<el-radio-group v-model="form.type" class="card__radio" @change="radioChange">
<el-row :gutter="20">
<el-col :span="12">
<el-card shadow="never" class="mb-16" :class="form.type === '0' ? 'active' : ''">
<el-radio value="0" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-blue" shape="square" :size="32">
<img src="@/assets/icon_document.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">通用型</p>
<el-text type="info">可以通过上传文件或手动录入方式构建知识库</el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
<el-col :span="12">
<el-card shadow="never" class="mb-16" :class="form.type === '1' ? 'active' : ''">
<el-radio value="1" size="large">
<div class="flex align-center">
<AppAvatar class="mr-8 avatar-purple" shape="square" :size="32">
<img src="@/assets/icon_web.svg" style="width: 58%" alt="" />
</AppAvatar>
<div>
<p class="mb-4">Web 站点</p>
<el-text type="info">通过网站链接同步方式构建知识库 </el-text>
</div>
</div>
</el-radio>
</el-card>
</el-col>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item label="Web 根地址" prop="source_url" v-if="form.type === '1'">
<el-input
v-model="form.source_url"
placeholder="请输入 Web 根地址"
@blur="form.source_url = form.source_url.trim()"
/>
</el-form-item>
<el-form-item label="选择器" v-if="form.type === '1'">
<el-input
v-model="form.selector"
placeholder="默认为 body可输入 .classname/#idname/tagname"
@blur="form.selector = form.selector.trim()"
/>
</el-form-item>
</el-form>
<!-- 上传文档 -->
<UploadComponent ref="UploadComponentRef" v-if="form.type === '0'" />
</div>
</el-scrollbar>
</template>
<script setup lang="ts">
import { ref, onMounted, reactive, watch } from 'vue'
import { useRouter, useRoute } from 'vue-router'
import BaseForm from '@/views/dataset/component/BaseForm.vue'
import UploadComponent from '@/views/dataset/component/UploadComponent.vue'
import { isAllPropertiesEmpty } from '@/utils/utils'
import datasetApi from '@/api/dataset'
import { MsgError, MsgSuccess } from '@/utils/message'
import useStore from '@/stores'
const { dataset } = useStore()
const route = useRoute()
const router = useRouter()
const {
params: { type }
} = route
const isCreate = type === 'create'
const BaseFormRef = ref()
const UploadComponentRef = ref()
const webFormRef = ref()
const loading = ref(false)
const form = ref<any>({
type: '0',
source_url: '',
selector: ''
})
const rules = reactive({
source_url: [{ required: true, message: '请输入 Web 根地址', trigger: 'blur' }]
})
watch(form.value, (value) => {
if (isAllPropertiesEmpty(value)) {
dataset.saveWebInfo(null)
} else {
dataset.saveWebInfo(value)
}
})
function radioChange() {
dataset.saveDocumentsFile([])
dataset.saveDocumentsType('')
form.value.source_url = ''
form.value.selector = ''
}
const onSubmit = async () => {
if (isCreate) {
if (form.value.type === '0') {
if ((await BaseFormRef.value?.validate()) && (await UploadComponentRef.value.validate())) {
if (UploadComponentRef.value.form.fileList.length > 50) {
MsgError('每次最多上传50个文件')
return false
} else {
/*
stores保存数据
*/
dataset.saveBaseInfo(BaseFormRef.value.form)
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
}
} else {
return false
}
} else {
if (await BaseFormRef.value?.validate()) {
await webFormRef.value.validate((valid: any) => {
if (valid) {
const obj = { ...BaseFormRef.value.form, ...form.value }
datasetApi.postWebDataset(obj, loading).then((res) => {
MsgSuccess('提交成功')
dataset.saveBaseInfo(null)
dataset.saveWebInfo(null)
router.push({ path: `/dataset/${res.data.id}/document` })
})
} else {
return false
}
})
} else {
return false
}
}
} else {
if (await UploadComponentRef.value.validate()) {
/*
stores保存数据
*/
dataset.saveDocumentsType(UploadComponentRef.value.form.fileType)
dataset.saveDocumentsFile(UploadComponentRef.value.form.fileList)
return true
} else {
return false
}
}
}
onMounted(() => {})
defineExpose({
onSubmit,
loading
})
</script>
<style scoped lang="scss">
.upload-document {
width: 70%;
margin: 0 auto;
margin-bottom: 20px;
}
</style>

View File

@ -141,4 +141,13 @@ const submit = async (formEl: FormInstance) => {
defineExpose({ open })
</script>
<style lang="scss" scope></style>
<style lang="scss" scope>
.edit-mark-dialog {
.el-dialog__header.show-close {
padding-right: 15px;
}
.el-dialog__headerbtn {
top: 13px;
}
}
</style>

View File

@ -23,7 +23,7 @@
v-if="isEdit"
v-model="form.content"
placeholder="请输入分段内容"
:maxLength="4096"
:maxLength="100000"
:preview="false"
:toolbars="toolbars"
style="height: 300px"
@ -31,7 +31,7 @@
:footers="footers"
>
<template #defFooters>
<span style="margin-left: -6px">/ 4096</span>
<span style="margin-left: -6px">/ 100000</span>
</template>
</MdEditor>
<MdPreview

View File

@ -55,6 +55,32 @@
placeholder="请给基础模型设置一个名称"
/>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label>
<span>模型类型</span>
@ -74,6 +100,7 @@
></el-option>
</el-select>
</el-form-item>
<el-form-item prop="model_name" :rules="base_form_data_rule.model_name">
<template #label>
<div class="flex align-center" style="display: inline-flex">
@ -135,6 +162,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message'
import { PermissionType, PermissionDesc } from '@/enums/model'
const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -150,17 +178,18 @@ const dialogVisible = ref<boolean>(false)
const base_form_data_rule = ref<FormRules>({
name: { required: true, trigger: 'blur', message: '模型名不能为空' },
permission_type: { required: true, trigger: 'change', message: '权限不能为空' },
model_type: { required: true, trigger: 'change', message: '模型类型不能为空' },
model_name: { required: true, trigger: 'change', message: '基础模型不能为空' }
})
const base_form_data = ref<{
name: string
permission_type: string
model_type: string
model_name: string
}>({ name: '', model_type: '', model_name: '' })
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
const credential_form_data = ref<Dict<any>>({})
@ -212,7 +241,7 @@ const list_base_model = (model_type: any) => {
}
const close = () => {
base_form_data.value = { name: '', model_type: '', model_name: '' }
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' }
credential_form_data.value = {}
model_form_field.value = []
base_model_list.value = []

View File

@ -48,6 +48,32 @@
placeholder="请给基础模型设置一个名称"
/>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.permission_type">
<template #label>
<span>权限</span>
</template>
<el-radio-group v-model="base_form_data.permission_type" class="card__radio">
<el-row :gutter="16">
<template v-for="(value, key) of PermissionType" :key="key">
<el-col :span="12">
<el-card
shadow="never"
class="mb-16"
:class="base_form_data.permission_type === key ? 'active' : ''"
>
<el-radio :value="key" size="large">
<p class="mb-4">{{ value }}</p>
<el-text type="info">
{{ PermissionDesc[key] }}
</el-text>
</el-radio>
</el-card>
</el-col>
</template>
</el-row>
</el-radio-group>
</el-form-item>
<el-form-item prop="model_type" :rules="base_form_data_rule.model_type">
<template #label>
<span>模型类型</span>
@ -128,7 +154,7 @@ import type { FormField } from '@/components/dynamics-form/type'
import DynamicsForm from '@/components/dynamics-form/index.vue'
import type { FormRules } from 'element-plus'
import { MsgSuccess } from '@/utils/message'
import AppIcon from '@/components/icons/AppIcon.vue'
import { PermissionType, PermissionDesc } from '@/enums/model'
const providerValue = ref<Provider>()
const dynamicsFormRef = ref<InstanceType<typeof DynamicsForm>>()
@ -151,11 +177,11 @@ const base_form_data_rule = ref<FormRules>({
const base_form_data = ref<{
name: string
permission_type: string
model_type: string
model_name: string
}>({ name: '', model_type: '', model_name: '' })
}>({ name: '', model_type: '', model_name: '', permission_type: 'PRIVATE' })
const credential_form_data = ref<Dict<any>>({})
@ -204,6 +230,7 @@ const open = (provider: Provider, model: Model) => {
base_form_data.value = {
name: model.name,
permission_type: model.permission_type,
model_type: model.model_type,
model_name: model.model_name
}
@ -214,7 +241,7 @@ const open = (provider: Provider, model: Model) => {
}
const close = () => {
base_form_data.value = { name: '', model_type: '', model_name: '' }
base_form_data.value = { name: '', model_type: '', model_name: '', permission_type: '' }
dynamicsFormRef.value?.ruleFormRef?.resetFields()
credential_form_data.value = {}
model_form_field.value = []

View File

@ -1,16 +1,30 @@
<template>
<card-box :title="model.name" shadow="hover" class="model-card">
<template #header>
<div class="flex align-center">
<div class="flex">
<span style="height: 32px; width: 32px" :innerHTML="icon" class="mr-12"></span>
<auto-tooltip :content="model.name" style="max-width: 40%">
{{ model.name }}
</auto-tooltip>
<div class="flex align-center" v-if="currentModel.status === 'ERROR'">
<el-tag type="danger" class="ml-8">失败</el-tag>
<el-tooltip effect="dark" :content="errMessage" placement="top">
<el-icon class="danger ml-4" size="20"><Warning /></el-icon>
</el-tooltip>
<div class="w-full">
<div class="flex" style="height: 22px">
<auto-tooltip :content="model.name" style="max-width: 40%">
{{ model.name }}
</auto-tooltip>
<span v-if="currentModel.status === 'ERROR'">
<el-tooltip effect="dark" :content="errMessage" placement="top">
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
</el-tooltip>
</span>
<span v-if="currentModel.status === 'PAUSE_DOWNLOAD'">
<el-tooltip effect="dark" content="暂停下载" placement="top">
<el-icon class="danger ml-4" size="18"><Warning /></el-icon>
</el-tooltip>
</span>
</div>
<div class="mt-4">
<el-tag v-if="model.permission_type === 'PRIVATE'" type="danger" class="danger-tag"
>私有</el-tag
>
<el-tag v-else type="info" class="info-tag">公有</el-tag>
</div>
</div>
</div>
</template>
@ -29,18 +43,14 @@
</div>
<!-- progress -->
<div class="progress-mask" v-if="currentModel.status === 'DOWNLOAD'">
<el-progress
type="circle"
:width="56"
color="#3370FF"
:percentage="progress"
class="percentage"
>
<template #default="{ percentage }">
<span class="percentage-value">{{ percentage }}%</span>
</template>
</el-progress>
<span class="percentage-label">正在下载 <span class="dotting"></span></span>
<DownloadLoading class="percentage" />
<div class="percentage-label flex-center">
正在下载中 <span class="dotting"></span>
<el-button link type="primary" class="ml-16" @click.stop="cancelDownload"
>取消下载</el-button
>
</div>
</div>
<template #mouseEnter>
@ -48,7 +58,13 @@
<el-tooltip effect="dark" content="修改" placement="top">
<el-button text @click.stop="openEditModel">
<el-icon>
<component :is="currentModel.status === 'ERROR' ? 'RefreshRight' : 'EditPen'" />
<component
:is="
currentModel.status === 'ERROR' || currentModel.status === 'PAUSE_DOWNLOAD'
? 'RefreshRight'
: 'EditPen'
"
/>
</el-icon>
</el-button>
</el-tooltip>
@ -68,6 +84,7 @@ import type { Provider, Model } from '@/api/type/model'
import ModelApi from '@/api/model'
import { computed, ref, onMounted, onBeforeUnmount } from 'vue'
import EditModel from '@/views/template/component/EditModel.vue'
import DownloadLoading from '@/components/loading/DownloadLoading.vue'
import { MsgConfirm } from '@/utils/message'
const props = defineProps<{
@ -94,27 +111,6 @@ const errMessage = computed(() => {
}
return ''
})
const progress = computed(() => {
if (currentModel.value) {
const down_model_chunk = currentModel.value.meta['down_model_chunk']
if (down_model_chunk) {
const maxObj = down_model_chunk
.filter((chunk: any) => chunk.index > 1)
.reduce(
(prev: any, current: any) => {
return (prev.index || 0) > (current.index || 0) ? prev : current
},
{ progress: 0 }
)
if (maxObj) {
return parseFloat(maxObj.progress?.toFixed(1))
}
return 0
}
return 0
}
return 0
})
const emit = defineEmits(['change', 'update:model'])
const eidtModelRef = ref<InstanceType<typeof EditModel>>()
let interval: any
@ -130,6 +126,13 @@ const deleteModel = () => {
})
.catch(() => {})
}
const cancelDownload = () => {
ModelApi.pauseDownload(props.model.id).then(() => {
downModel.value = undefined
emit('change')
})
}
const openEditModel = () => {
const provider = props.provider_list.find((p) => p.provider === props.model.provider)
if (provider) {
@ -197,21 +200,21 @@ onBeforeUnmount(() => {
z-index: 99;
text-align: center;
.percentage {
top: 50%;
transform: translateY(-65%);
margin-top: 55px;
margin-bottom: 16px;
}
.percentage-value {
display: block;
font-size: 12px;
color: var(--el-color-primary);
}
// .percentage-value {
// display: flex;
// font-size: 13px;
// align-items: center;
// color: var(--app-text-color-secondary);
// }
.percentage-label {
display: block;
margin-top: 45px;
margin-top: 50px;
margin-left: 10px;
font-size: 12px;
color: var(--el-color-primary);
font-size: 13px;
color: var(--app-text-color-secondary);
}
}
}

View File

@ -145,10 +145,7 @@ function createUser() {
title.value = '创建用户'
UserDialogRef.value.open()
} else {
MsgAlert(
'提示',
'社区版最多支持 2 个用户如需拥有更多用户请联系我们https://fit2cloud.com/)。'
)
MsgAlert('提示', '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。')
}
})
}