mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 10:12:51 +00:00
feat: 支持向量模型
This commit is contained in:
parent
e600b91de2
commit
75b9b17e2e
|
|
@ -39,6 +39,7 @@ from common.util.file_util import get_file_content
|
|||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
|
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
raise AppApiException(500, "文档id不正确")
|
||||
|
||||
@staticmethod
|
||||
def post_embedding_paragraph(chat_record, paragraph_id):
|
||||
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
# 发送向量化事件
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
|
||||
return chat_record
|
||||
|
||||
@post(post_function=post_embedding_paragraph)
|
||||
|
|
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
|
|||
chat_record.improve_paragraph_id_list.append(paragraph.id)
|
||||
# 添加标注
|
||||
chat_record.save()
|
||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id
|
||||
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
|
||||
|
||||
class Operate(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
||||
|
|
|
|||
|
|
@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
|
|||
|
||||
|
||||
def poxy(poxy_function):
|
||||
def inner(args):
|
||||
work_thread_pool.submit(poxy_function, args)
|
||||
def inner(args, **keywords):
|
||||
work_thread_pool.submit(poxy_function, args, **keywords)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def embedding_poxy(poxy_function):
|
||||
def inner(args):
|
||||
embedding_thread_pool.submit(poxy_function, args)
|
||||
def inner(args, **keywords):
|
||||
embedding_thread_pool.submit(poxy_function, args, **keywords)
|
||||
|
||||
return inner
|
||||
|
|
|
|||
|
|
@ -85,8 +85,8 @@ class ListenerManagement:
|
|||
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
|
||||
|
||||
@staticmethod
|
||||
def embedding_by_problem(args):
|
||||
VectorStore.get_embedding_vector().save(**args)
|
||||
def embedding_by_problem(args, embedding_model: Embeddings):
|
||||
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
|
||||
|
||||
@staticmethod
|
||||
@embedding_poxy
|
||||
|
|
@ -165,7 +165,7 @@ class ListenerManagement:
|
|||
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
|
||||
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
|
||||
for document in document_list:
|
||||
ListenerManagement.embedding_by_document(document.id, embedding_model)
|
||||
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
|
||||
except Exception as e:
|
||||
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -145,5 +145,5 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
|
|||
|
||||
|
||||
def get_embedding_model_by_dataset_id(dataset_id: str):
|
||||
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id)
|
||||
return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
|
||||
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
|
||||
return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))
|
||||
|
|
|
|||
|
|
@ -745,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
|
|||
def re_embedding(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('id'))
|
||||
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
|
||||
|
||||
def list_application(self, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
|
|||
from common.util.fork import Fork
|
||||
from common.util.split_model import get_split_model
|
||||
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
|
||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
|
||||
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
|
||||
get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
|
|
@ -392,7 +393,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
problem_paragraph_mapping_list) > 0 else None
|
||||
# 向量化
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
else:
|
||||
document.status = Status.error
|
||||
document.save()
|
||||
|
|
@ -405,6 +407,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
class Operate(ApiMixin, serializers.Serializer):
|
||||
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
|
||||
"文档id"))
|
||||
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
|
||||
|
||||
@staticmethod
|
||||
def get_request_params_api():
|
||||
|
|
@ -530,7 +533,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
document_id = self.data.get("document_id")
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
|
||||
@transaction.atomic
|
||||
def delete(self):
|
||||
|
|
@ -599,8 +603,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return True
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(result, document_id):
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id)
|
||||
def post_embedding(result, document_id, dataset_id):
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -646,7 +651,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
document_id = str(document_model.id)
|
||||
return DocumentSerializers.Operate(
|
||||
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True), document_id
|
||||
with_valid=True), document_id, dataset_id
|
||||
|
||||
@staticmethod
|
||||
def get_sync_handler(dataset_id):
|
||||
|
|
@ -803,9 +808,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(document_list):
|
||||
def post_embedding(document_list, dataset_id):
|
||||
for document_dict in document_list:
|
||||
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
|
||||
return document_list
|
||||
|
||||
@post(post_function=post_embedding)
|
||||
|
|
@ -846,7 +852,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
return [],
|
||||
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
|
||||
return native_search(query_set, select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
|
||||
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
|
||||
with_search_one=False), dataset_id
|
||||
|
||||
@staticmethod
|
||||
def _batch_sync(document_id_list: List[str]):
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from common.util.common import post
|
|||
from common.util.field_message import ErrMessage
|
||||
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
|
||||
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
|
||||
ProblemParagraphManage
|
||||
ProblemParagraphManage, get_embedding_model_by_dataset_id
|
||||
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
|
||||
from embedding.models import SourceType
|
||||
|
||||
|
|
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
paragraph_id=self.data.get('paragraph_id'),
|
||||
dataset_id=self.data.get('dataset_id'))
|
||||
problem_paragraph_mapping.save()
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
|
|
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
}, embedding_model=model)
|
||||
|
||||
return ProblemSerializers.Operate(
|
||||
data={'dataset_id': self.data.get('dataset_id'),
|
||||
|
|
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
problem_id=problem.id)
|
||||
problem_paragraph_mapping.save()
|
||||
if with_embedding:
|
||||
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
|
||||
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
|
||||
'is_active': True,
|
||||
'source_type': SourceType.PROBLEM,
|
||||
|
|
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
'document_id': self.data.get('document_id'),
|
||||
'paragraph_id': self.data.get('paragraph_id'),
|
||||
'dataset_id': self.data.get('dataset_id'),
|
||||
})
|
||||
}, embedding_model=model)
|
||||
|
||||
def un_association(self, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -454,13 +456,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
raise AppApiException(500, "段落id不存在")
|
||||
|
||||
@staticmethod
|
||||
def post_embedding(paragraph, instance):
|
||||
def post_embedding(paragraph, instance, dataset_id):
|
||||
if 'is_active' in instance and instance.get('is_active') is not None:
|
||||
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
|
||||
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
|
||||
s.send(paragraph.get('id'))
|
||||
else:
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
|
||||
return paragraph
|
||||
|
||||
@post(post_embedding)
|
||||
|
|
@ -508,7 +511,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
|
||||
_paragraph.save()
|
||||
update_document_char_length(self.data.get('document_id'))
|
||||
return self.one(), instance
|
||||
return self.one(), instance, self.data.get('dataset_id')
|
||||
|
||||
def get_problem_list(self):
|
||||
ProblemParagraphMapping(ProblemParagraphMapping)
|
||||
|
|
@ -582,7 +585,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
|
|||
# 修改长度
|
||||
update_document_char_length(document_id)
|
||||
if with_embedding:
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
|
||||
model = get_embedding_model_by_dataset_id(dataset_id)
|
||||
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
|
||||
return ParagraphSerializers.Operate(
|
||||
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
|
||||
with_valid=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue