feat: persist chat records synchronously during Q&A, optimize the embedding execution logic, and change the model download directory

shaohuzhang1 2023-12-21 16:55:11 +08:00
parent b6f7537c2b
commit 2c6135a929
10 changed files with 120 additions and 39 deletions


@@ -45,7 +45,9 @@ class ChatMessage:
source_id: str,
answer: str,
message_tokens: int,
answer_token: int):
answer_token: int,
chat_model=None,
chat_message=None):
self.id = id
self.problem = problem
self.title = title
@@ -59,6 +61,8 @@ class ChatMessage:
self.answer = answer
self.message_tokens = message_tokens
self.answer_token = answer_token
self.chat_model = chat_model
self.chat_message = chat_message
def get_chat_message(self):
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
@@ -85,10 +89,13 @@ class ChatInfo:
def append_chat_message(self, chat_message: ChatMessage):
self.chat_message_list.append(chat_message)
if self.application_id is not None:
# Insert into the database
event.ListenerChatMessage.record_chat_message_signal.send(
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
chat_message)
)
# Update token counts asynchronously
event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
def get_context_message(self):
start_index = len(self.chat_message_list) - self.dialogue_number
@@ -176,8 +183,10 @@ class ChatMessageSerializer(serializers.Serializer):
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
chat_model.get_num_tokens(all_text)))
source_id, all_text,
0,
0,
chat_message=chat_message, chat_model=chat_model))
# Reset the cache
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)


@@ -61,7 +61,7 @@ class ChatSerializers(serializers.Serializer):
query_dict = {'application_id': self.data.get("application_id"), 'create_time__gte': end_time}
if 'abstract' in self.data and self.data.get('abstract') is not None:
query_dict['abstract'] = self.data.get('abstract')
return QuerySet(Chat).filter(**query_dict)
return QuerySet(Chat).filter(**query_dict).order_by("-create_time")
def list(self, with_valid=True):
if with_valid:


@@ -6,6 +6,7 @@
@date: 2023/10/20 14:01
@desc:
"""
import logging
from blinker import signal
from django.db.models import QuerySet
@@ -25,9 +26,9 @@ class RecordChatMessageArgs:
class ListenerChatMessage:
record_chat_message_signal = signal("record_chat_message")
update_chat_message_token_signal = signal("update_chat_message_token")
@staticmethod
@poxy
def record_chat_message(args: RecordChatMessageArgs):
if not QuerySet(Chat).filter(id=args.chat_id).exists():
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
@@ -49,6 +50,18 @@ class ListenerChatMessage:
except Exception as e:
print(e)
@staticmethod
@poxy
def update_token(chat_message: ChatMessage):
if chat_message.chat_model is not None:
logging.getLogger("max_kb").info("start updating tokens")
message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
# Update the token counts
QuerySet(ChatRecord).filter(id=chat_message.id).update(
**{'message_tokens': message_token, 'answer_tokens': answer_token})
def run(self):
# Record the conversation
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)
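For context, a minimal sketch of the blinker wiring these listeners rely on, with stand-in names: run() connects the receivers once at startup, and the chat pipeline later fires the signal for every answered message. The real receivers are additionally wrapped by the poxy decorator, which this sketch leaves out.

from blinker import signal

demo_update_token_signal = signal("demo_update_token")

def demo_update_token(chat_message):
    # Receiver: blinker passes the object given to send() as the first argument
    print("updating token counts for", chat_message)

# What ListenerChatMessage.run() does at startup
demo_update_token_signal.connect(demo_update_token)

# What ChatInfo.append_chat_message() does per chat message
demo_update_token_signal.send("chat-message-placeholder")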


@@ -109,9 +109,10 @@ class ListenerManagement:
:param dataset_id: knowledge base id
:return: None
"""
max_kb.info(f"向量化数据集{dataset_id}")
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
try:
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list:
ListenerManagement.embedding_by_document(document.id)
except Exception as e:


@@ -43,3 +43,14 @@ def get_exec_method(clazz_: str, method_: str):
package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)])
package_model = importlib.import_module(package)
return getattr(getattr(package_model, clazz_name), method_)
def post(post_function):
def inner(func):
def run(*args, **kwargs):
result = func(*args, **kwargs)
return post_function(*result)
return run
return inner
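A brief usage sketch of the post decorator defined above; all names below are hypothetical. The decorated function returns a tuple, post unpacks that tuple into post_function once the call finishes, and the caller gets whatever post_function returns:

def notify(result, dataset_id):
    # Stand-in for sending an embedding signal after the save completes
    print("embedding dataset", dataset_id)
    return result

@post(post_function=notify)
def save_dataset():
    # ... persist the dataset inside the transaction ...
    return {"name": "demo"}, "fake-dataset-id"  # tuple: (result, dataset_id)

data = save_dataset()  # notify() runs afterwards; data == {"name": "demo"}

Because the return value is star-unpacked, a wrapped function that only has a result to report must return a one-element tuple, which is why batch_save below ends its return statement with a trailing comma.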


@@ -23,6 +23,7 @@ from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
@@ -207,6 +208,13 @@ class DataSetSerializers(serializers.ModelSerializer):
super().is_valid(raise_exception=True)
return True
@staticmethod
def post_embedding_dataset(document_list, dataset_id):
# Send the embedding event
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
return document_list
@post(post_function=post_embedding_dataset)
@transaction.atomic
def save(self, user: User):
dataset_id = uuid.uuid1()
@@ -234,11 +242,11 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# Bulk insert problems
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# Send the embedding event
ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
# Response data
return {**DataSetSerializers(dataset).data,
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
with_valid=True)}, dataset_id
@staticmethod
def get_response_body_api():


@@ -21,6 +21,7 @@ from common.db.search import native_search, native_page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.split_model import SplitModel, get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem
@@ -207,7 +208,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
raise AppApiException(10000, "knowledge base id does not exist")
return True
def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
@staticmethod
def post_embedding(result, document_id):
ListenerManagement.embedding_by_document_signal.send(document_id)
return result
@post(post_function=post_embedding)
@transaction.atomic
def save(self, instance: Dict, with_valid=False, **kwargs):
if with_valid:
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)
@@ -222,11 +230,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# Bulk insert problems
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
if with_embedding:
ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
document_id = str(document_model.id)
return DocumentSerializers.Operate(
data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
with_valid=True)
data={'dataset_id': dataset_id, 'document_id': document_id}).one(
with_valid=True), document_id
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
@@ -333,14 +340,43 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
def get_request_body_api():
return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
@staticmethod
def post_embedding(document_list):
for document_dict in document_list:
ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
return document_list
@post(post_function=post_embedding)
@transaction.atomic
def batch_save(self, instance_list: List[Dict], with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
create_data = {'dataset_id': self.data.get("dataset_id")}
return [DocumentSerializers.Create(data=create_data).save(instance,
with_valid=True)
for instance in instance_list]
dataset_id = self.data.get("dataset_id")
document_model_list = []
paragraph_model_list = []
problem_model_list = []
# Insert documents
for document in instance_list:
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
document)
document_model_list.append(document_paragraph_dict_model.get('document'))
for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
paragraph_model_list.append(paragraph)
for problem in document_paragraph_dict_model.get('problem_model_list'):
problem_model_list.append(problem)
# Insert documents
QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
# Bulk insert paragraphs
QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
# Bulk insert problems
QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
# Query documents
query_set = QuerySet(model=Document)
query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
return native_search(query_set, select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):


@@ -19,6 +19,7 @@ from common.db.search import page_search
from common.event.listener_manage import ListenerManagement
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from dataset.models import Paragraph, Problem, Document
from dataset.serializers.common_serializers import update_document_char_length
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer
@@ -82,6 +83,17 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
raise AppApiException(500, "paragraph id does not exist")
@staticmethod
def post_embedding(paragraph, instance):
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(paragraph.get('id'))
else:
ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
return paragraph
@post(post_embedding)
@transaction.atomic
def edit(self, instance: Dict):
self.is_valid()
@@ -125,11 +137,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
_paragraph.save()
update_document_char_length(self.data.get('document_id'))
if 'is_active' in instance and instance.get('is_active') is not None:
s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
s.send(self.data.get('paragraph_id'))
return self.one()
return self.one(), instance
def get_problem_list(self):
return [ProblemSerializer(problem).data for problem in


@@ -68,17 +68,12 @@ class BaseVectorStore(ABC):
:param trample_num: number of downvotes
:return: bool
"""
# Acquire the lock
lock.acquire()
try:
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
trample_num, embedding)
finally:
# Release the lock
lock.release()
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
trample_num, embedding)
def batch_save(self, data_list: List[Dict], embedding=None):
# Acquire the lock


@@ -16,35 +16,35 @@ prefix_dir = "/opt/maxkb/model"
model_config = [
{
'download_params': {
'cache_dir': os.path.join(prefix_dir, 'base'),
'cache_dir': os.path.join(prefix_dir, 'base/hub'),
'pretrained_model_name_or_path': 'gpt2'
},
'download_function': GPT2TokenizerFast.from_pretrained
},
{
'download_params': {
'cache_dir': os.path.join(prefix_dir, 'base'),
'cache_dir': os.path.join(prefix_dir, 'base/hub'),
'pretrained_model_name_or_path': 'gpt2-medium'
},
'download_function': GPT2TokenizerFast.from_pretrained
},
{
'download_params': {
'cache_dir': os.path.join(prefix_dir, 'base'),
'cache_dir': os.path.join(prefix_dir, 'base/hub'),
'pretrained_model_name_or_path': 'gpt2-large'
},
'download_function': GPT2TokenizerFast.from_pretrained
},
{
'download_params': {
'cache_dir': os.path.join(prefix_dir, 'base'),
'cache_dir': os.path.join(prefix_dir, 'base/hub'),
'pretrained_model_name_or_path': 'gpt2-xl'
},
'download_function': GPT2TokenizerFast.from_pretrained
},
{
'download_params': {
'cache_dir': os.path.join(prefix_dir, 'base'),
'cache_dir': os.path.join(prefix_dir, 'base/hub'),
'pretrained_model_name_or_path': 'distilgpt2'
},
'download_function': GPT2TokenizerFast.from_pretrained
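The driver that consumes model_config is not part of this diff; presumably each entry is downloaded by calling its download_function with its download_params, roughly as in this assumed sketch:

# Assumed driver loop (not shown in this diff): pre-download every configured model
for config in model_config:
    config['download_function'](**config['download_params'])

With the updated cache_dir, the tokenizer files are cached under /opt/maxkb/model/base/hub rather than /opt/maxkb/model/base, presumably to mirror the default Hugging Face hub cache layout.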