feat: 支持向量模型

This commit is contained in:
shaohuzhang1 2024-07-18 10:26:16 +08:00
parent e600b91de2
commit 75b9b17e2e
7 changed files with 43 additions and 29 deletions

View File

@ -39,6 +39,7 @@ from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.common_serializers import get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -533,9 +534,10 @@ class ChatRecordSerializer(serializers.Serializer):
raise AppApiException(500, "文档id不正确")
@staticmethod
def post_embedding_paragraph(chat_record, paragraph_id):
def post_embedding_paragraph(chat_record, paragraph_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
# 发送向量化事件
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph_id, embedding_model=model)
return chat_record
@post(post_function=post_embedding_paragraph)
@ -573,7 +575,7 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record.improve_paragraph_id_list.append(paragraph.id)
# 添加标注
chat_record.save()
return ChatRecordSerializerModel(chat_record).data, paragraph.id
return ChatRecordSerializerModel(chat_record).data, paragraph.id, dataset_id
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

View File

@ -14,14 +14,14 @@ embedding_thread_pool = ThreadPoolExecutor(3)
def poxy(poxy_function):
def inner(args):
work_thread_pool.submit(poxy_function, args)
def inner(args, **keywords):
work_thread_pool.submit(poxy_function, args, **keywords)
return inner
def embedding_poxy(poxy_function):
def inner(args):
embedding_thread_pool.submit(poxy_function, args)
def inner(args, **keywords):
embedding_thread_pool.submit(poxy_function, args, **keywords)
return inner

View File

@ -85,8 +85,8 @@ class ListenerManagement:
delete_embedding_by_dataset_id_list_signal = signal("delete_embedding_by_dataset_id_list")
@staticmethod
def embedding_by_problem(args):
VectorStore.get_embedding_vector().save(**args)
def embedding_by_problem(args, embedding_model: Embeddings):
VectorStore.get_embedding_vector().save(**args, embedding=embedding_model)
@staticmethod
@embedding_poxy
@ -165,7 +165,7 @@ class ListenerManagement:
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list:
ListenerManagement.embedding_by_document(document.id, embedding_model)
ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model)
except Exception as e:
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
finally:

View File

@ -145,5 +145,5 @@ def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
def get_embedding_model_by_dataset_id(dataset_id: str):
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id)
return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
return EmbeddingModelManage.get_model(dataset_id, lambda _id: get_model(dataset.embedding_mode))

View File

@ -745,7 +745,8 @@ class DataSetSerializers(serializers.ModelSerializer):
def re_embedding(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'))
model = get_embedding_model_by_dataset_id(self.data.get('id'))
ListenerManagement.embedding_by_dataset_signal.send(self.data.get('id'), embedding_model=model)
def list_application(self, with_valid=True):
if with_valid:

View File

@ -41,7 +41,8 @@ from common.util.file_util import get_file_content
from common.util.fork import Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage
from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer
from smartdoc.conf import PROJECT_DIR
@ -392,7 +393,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
problem_paragraph_mapping_list) > 0 else None
# 向量化
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(document_id)
model = get_embedding_model_by_dataset_id(dataset_id=document.dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
else:
document.status = Status.error
document.save()
@ -405,6 +407,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
class Operate(ApiMixin, serializers.Serializer):
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
"文档id"))
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id"))
@staticmethod
def get_request_params_api():
@ -530,7 +533,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
if with_valid:
self.is_valid(raise_exception=True)
document_id = self.data.get("document_id")
ListenerManagement.embedding_by_document_signal.send(document_id)
model = get_embedding_model_by_dataset_id(dataset_id=self.data.get('dataset_id'))
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
@transaction.atomic
def delete(self):
@ -599,8 +603,9 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return True
@staticmethod
def post_embedding(result, document_id):
ListenerManagement.embedding_by_document_signal.send(document_id)
def post_embedding(result, document_id, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_id, embedding_model=model)
return result
@staticmethod
@ -646,7 +651,7 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
document_id = str(document_model.id)
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id
with_valid=True), document_id, dataset_id
@staticmethod
def get_sync_handler(dataset_id):
@ -803,9 +808,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
@staticmethod
def post_embedding(document_list):
def post_embedding(document_list, dataset_id):
for document_dict in document_list:
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'), embedding_model=model)
return document_list
@post(post_function=post_embedding)
@ -846,7 +852,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
return [],
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')),
with_search_one=False), dataset_id
@staticmethod
def _batch_sync(document_id_list: List[str]):

View File

@ -22,7 +22,7 @@ from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \
ProblemParagraphManage
ProblemParagraphManage, get_embedding_model_by_dataset_id
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType
@ -132,6 +132,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id=self.data.get('paragraph_id'),
dataset_id=self.data.get('dataset_id'))
problem_paragraph_mapping.save()
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
if with_embedding:
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
@ -140,7 +141,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
}, embedding_model=model)
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'),
@ -227,6 +228,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
problem_id=problem.id)
problem_paragraph_mapping.save()
if with_embedding:
model = get_embedding_model_by_dataset_id(self.data.get('dataset_id'))
ListenerManagement.embedding_by_problem_signal.send({'text': problem.content,
'is_active': True,
'source_type': SourceType.PROBLEM,
@ -234,7 +236,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
})
}, embedding_model=model)
def un_association(self, with_valid=True):
if with_valid:
@ -454,13 +456,14 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
raise AppApiException(500, "段落id不存在")
@staticmethod
def post_embedding(paragraph, instance):
def post_embedding(paragraph, instance, dataset_id):
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(paragraph.get('id'))
else:
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'), embedding_model=model)
return paragraph
@post(post_embedding)
@ -508,7 +511,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
_paragraph.save()
update_document_char_length(self.data.get('document_id'))
return self.one(), instance
return self.one(), instance, self.data.get('dataset_id')
def get_problem_list(self):
ProblemParagraphMapping(ProblemParagraphMapping)
@ -582,7 +585,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
# 修改长度
update_document_char_length(document_id)
if with_embedding:
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id))
model = get_embedding_model_by_dataset_id(dataset_id)
ListenerManagement.embedding_by_paragraph_signal.send(str(paragraph.id), embedding_model=model)
return ParagraphSerializers.Operate(
data={'paragraph_id': str(paragraph.id), 'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True)