Mirror of https://github.com/1Panel-dev/MaxKB.git, synced 2025-12-26 01:33:05 +00:00

feat: store the chat log synchronously when answering, optimize the vectorization execution logic, and change the model download directory

This commit is contained in:
parent  b6f7537c2b
commit  2c6135a929

@@ -45,7 +45,9 @@ class ChatMessage:
                  source_id: str,
                  answer: str,
                  message_tokens: int,
-                 answer_token: int):
+                 answer_token: int,
+                 chat_model=None,
+                 chat_message=None):
         self.id = id
         self.problem = problem
         self.title = title

@@ -59,6 +61,8 @@ class ChatMessage:
         self.answer = answer
         self.message_tokens = message_tokens
         self.answer_token = answer_token
+        self.chat_model = chat_model
+        self.chat_message = chat_message
 
     def get_chat_message(self):
         return MessageManagement.get_message(self.problem, self.paragraph, self.problem)

@@ -85,10 +89,13 @@ class ChatInfo:
     def append_chat_message(self, chat_message: ChatMessage):
         self.chat_message_list.append(chat_message)
         if self.application_id is not None:
+            # insert into the database
             event.ListenerChatMessage.record_chat_message_signal.send(
                 event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
                                             chat_message)
             )
+            # update the token counts asynchronously
+            event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
 
     def get_context_message(self):
         start_index = len(self.chat_message_list) - self.dialogue_number

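Both `.send(...)` calls above use the blinker signal pattern: this method only emits events, and whatever called `connect()` on those signals (the `ListenerChatMessage.run` wiring shown further down) does the actual persistence and token accounting. A minimal self-contained sketch of the pattern, with illustrative names that are not from this commit:

    from blinker import signal

    record_signal = signal("record_chat_message")

    def listener(args):
        # blinker passes send()'s first argument to receivers as the sender
        print("would persist:", args)

    record_signal.connect(listener)
    record_signal.send({"chat_id": 1})  # prints: would persist: {'chat_id': 1}
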
@@ -176,8 +183,10 @@ class ChatMessageSerializer(serializers.Serializer):
             ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
                         paragraph_id,
                         source_type,
-                        source_id, all_text, chat_model.get_num_tokens_from_messages(chat_message),
-                        chat_model.get_num_tokens(all_text)))
+                        source_id, all_text,
+                        0,
+                        0,
+                        chat_message=chat_message, chat_model=chat_model))
         # reset the cache
         chat_cache.set(chat_id,
                       chat_info, timeout=60 * 30)

@@ -61,7 +61,7 @@ class ChatSerializers(serializers.Serializer):
             query_dict = {'application_id': self.data.get("application_id"), 'create_time__gte': end_time}
             if 'abstract' in self.data and self.data.get('abstract') is not None:
                 query_dict['abstract'] = self.data.get('abstract')
-            return QuerySet(Chat).filter(**query_dict)
+            return QuerySet(Chat).filter(**query_dict).order_by("-create_time")
 
         def list(self, with_valid=True):
             if with_valid:

@@ -6,6 +6,7 @@
     @date:2023/10/20 14:01
     @desc:
 """
+import logging
 
 from blinker import signal
 from django.db.models import QuerySet

@@ -25,9 +26,9 @@ class RecordChatMessageArgs:
 
 class ListenerChatMessage:
     record_chat_message_signal = signal("record_chat_message")
+    update_chat_message_token_signal = signal("update_chat_message_token")
 
     @staticmethod
     @poxy
     def record_chat_message(args: RecordChatMessageArgs):
         if not QuerySet(Chat).filter(id=args.chat_id).exists():
             Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()

@@ -49,6 +50,18 @@ class ListenerChatMessage:
         except Exception as e:
             print(e)
 
+    @staticmethod
+    @poxy
+    def update_token(chat_message: ChatMessage):
+        if chat_message.chat_model is not None:
+            logging.getLogger("max_kb").info("start updating tokens")
+            message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
+            answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
+            # update the token counts
+            QuerySet(ChatRecord).filter(id=chat_message.id).update(
+                **{'message_tokens': message_token, 'answer_tokens': answer_token})
+
     def run(self):
         # record the conversation
         ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
+        ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)

@@ -109,9 +109,10 @@ class ListenerManagement:
         :param dataset_id: knowledge base id
         :return: None
         """
-        max_kb.info(f"vectorizing dataset {dataset_id}")
+        max_kb.info(f"start ---> vectorizing dataset: {dataset_id}")
         try:
             document_list = QuerySet(Document).filter(dataset_id=dataset_id)
+            max_kb.info(f"dataset documents: {[d.name for d in document_list]}")
             for document in document_list:
                 ListenerManagement.embedding_by_document(document.id)
         except Exception as e:

@@ -43,3 +43,14 @@ def get_exec_method(clazz_: str, method_: str):
     package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)])
     package_model = importlib.import_module(package)
     return getattr(getattr(package_model, clazz_name), method_)
+
+
+def post(post_function):
+    def inner(func):
+        def run(*args, **kwargs):
+            result = func(*args, **kwargs)
+            return post_function(*result)
+
+        return run
+
+    return inner

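The new `post` decorator calls `post_function(*result)`, so any function it wraps must return a tuple whose elements become the post hook's arguments; the wrapped function's overall return value is whatever the hook returns. A minimal usage sketch with illustrative names that are not from this commit:

    def log_result(value, tag):
        print(f"{tag}: {value}")
        return value

    @post(log_result)
    def compute():
        # must return a tuple: it is unpacked into log_result(value, tag)
        return 42, "answer"

    compute()  # prints "answer: 42" and evaluates to 42
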
@@ -23,6 +23,7 @@ from common.db.sql_execute import select_list
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from common.util.file_util import get_file_content
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem
 from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer

@@ -207,6 +208,13 @@ class DataSetSerializers(serializers.ModelSerializer):
             super().is_valid(raise_exception=True)
             return True
 
+        @staticmethod
+        def post_embedding_dataset(document_list, dataset_id):
+            # send the vectorization event
+            ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
+            return document_list
+
+        @post(post_function=post_embedding_dataset)
         @transaction.atomic
         def save(self, user: User):
             dataset_id = uuid.uuid1()

@@ -234,11 +242,11 @@ class DataSetSerializers(serializers.ModelSerializer):
             QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
             # batch insert problems
             QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
-            # send the vectorization event
-            ListenerManagement.embedding_by_dataset_signal.send(str(dataset.id))
 
             # response data
             return {**DataSetSerializers(dataset).data,
-                    'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(with_valid=True)}
+                    'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list(
+                        with_valid=True)}, dataset_id
 
         @staticmethod
         def get_response_body_api():

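Note the decorator order on `save`: `@post` wraps `@transaction.atomic`, so the embedding signal that used to fire inside the transaction now fires only after the atomic block has committed, and the listener reads committed rows. A hedged reconstruction of the call order, not literal code from this commit:

    # inner = transaction.atomic(save)
    # result, dataset_id = inner(user)            # the atomic block commits here
    # post_embedding_dataset(result, dataset_id)  # listener now sees committed data
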
@@ -21,6 +21,7 @@ from common.db.search import native_search, native_page_search
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from common.util.file_util import get_file_content
 from common.util.split_model import SplitModel, get_split_model
 from dataset.models.data_set import DataSet, Document, Paragraph, Problem

@@ -207,7 +208,14 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
                 raise AppApiException(10000, "knowledge base id does not exist")
             return True
 
-        def save(self, instance: Dict, with_valid=False, with_embedding=True, **kwargs):
+        @staticmethod
+        def post_embedding(result, document_id):
+            ListenerManagement.embedding_by_document_signal.send(document_id)
+            return result
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
+        def save(self, instance: Dict, with_valid=False, **kwargs):
             if with_valid:
                 DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
                 self.is_valid(raise_exception=True)

@@ -222,11 +230,10 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
             QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
             # batch insert problems
             QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
-            if with_embedding:
-                ListenerManagement.embedding_by_document_signal.send(str(document_model.id))
+            document_id = str(document_model.id)
             return DocumentSerializers.Operate(
-                data={'dataset_id': dataset_id, 'document_id': str(document_model.id)}).one(
-                with_valid=True)
+                data={'dataset_id': dataset_id, 'document_id': document_id}).one(
+                with_valid=True), document_id
 
         @staticmethod
         def get_document_paragraph_model(dataset_id, instance: Dict):

@@ -333,14 +340,43 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
         def get_request_body_api():
             return openapi.Schema(type=openapi.TYPE_ARRAY, items=DocumentSerializers.Create.get_request_body_api())
 
+        @staticmethod
+        def post_embedding(document_list):
+            for document_dict in document_list:
+                ListenerManagement.embedding_by_document_signal.send(document_dict.get('id'))
+            return document_list
+
+        @post(post_function=post_embedding)
+        @transaction.atomic
         def batch_save(self, instance_list: List[Dict], with_valid=True):
             if with_valid:
                 self.is_valid(raise_exception=True)
                 DocumentInstanceSerializer(many=True, data=instance_list).is_valid(raise_exception=True)
-            create_data = {'dataset_id': self.data.get("dataset_id")}
-            return [DocumentSerializers.Create(data=create_data).save(instance,
-                                                                      with_valid=True)
-                    for instance in instance_list]
+            dataset_id = self.data.get("dataset_id")
+            document_model_list = []
+            paragraph_model_list = []
+            problem_model_list = []
+            # insert documents
+            for document in instance_list:
+                document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id,
+                                                                                                        document)
+                document_model_list.append(document_paragraph_dict_model.get('document'))
+                for paragraph in document_paragraph_dict_model.get('paragraph_model_list'):
+                    paragraph_model_list.append(paragraph)
+                for problem in document_paragraph_dict_model.get('problem_model_list'):
+                    problem_model_list.append(problem)
+
+            # insert documents
+            QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None
+            # batch insert paragraphs
+            QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None
+            # batch insert problems
+            QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None
+            # query documents
+            query_set = QuerySet(model=Document)
+            query_set = query_set.filter(**{'id__in': [d.id for d in document_model_list]})
+            return native_search(query_set, select_string=get_file_content(
+                os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql')), with_search_one=False),
 
 
 def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int):

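Two details of the rewritten `batch_save` are easy to miss. First, the trailing comma on the final `return` wraps the `native_search` result in a 1-tuple, which the `post` decorator then unpacks as the single `document_list` argument of `post_embedding`. Second, the per-document `Create(...).save(...)` loop is replaced by three `bulk_create` batches, one query per model instead of several per document. A tiny standalone illustration of the 1-tuple behavior:

    def f():
        return [1, 2, 3],   # trailing comma: f() returns ([1, 2, 3],)

    result = f()
    print(result)    # ([1, 2, 3],)
    print(*result)   # [1, 2, 3]  <- what the post hook receives
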
@@ -19,6 +19,7 @@ from common.db.search import page_search
 from common.event.listener_manage import ListenerManagement
 from common.exception.app_exception import AppApiException
 from common.mixins.api_mixin import ApiMixin
+from common.util.common import post
 from dataset.models import Paragraph, Problem, Document
 from dataset.serializers.common_serializers import update_document_char_length
 from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer

@@ -82,6 +83,17 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
             if not QuerySet(Paragraph).filter(id=self.data.get('paragraph_id')).exists():
                 raise AppApiException(500, "paragraph id does not exist")
 
+        @staticmethod
+        def post_embedding(paragraph, instance):
+            if 'is_active' in instance and instance.get('is_active') is not None:
+                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
+                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
+                s.send(paragraph.get('id'))
+            else:
+                ListenerManagement.embedding_by_paragraph_signal.send(paragraph.get('id'))
+            return paragraph
+
+        @post(post_embedding)
         @transaction.atomic
         def edit(self, instance: Dict):
             self.is_valid()

@@ -125,11 +137,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
 
             _paragraph.save()
             update_document_char_length(self.data.get('document_id'))
-            if 'is_active' in instance and instance.get('is_active') is not None:
-                s = (ListenerManagement.enable_embedding_by_paragraph_signal if instance.get(
-                    'is_active') else ListenerManagement.disable_embedding_by_paragraph_signal)
-                s.send(self.data.get('paragraph_id'))
-            return self.one()
+            return self.one(), instance
 
         def get_problem_list(self):
             return [ProblemSerializer(problem).data for problem in

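Since `edit` is now decorated with `@post(post_embedding)`, returning `(self.one(), instance)` hands both values to the post hook, so the enable/disable-vs-re-embed decision that used to sit inside the transaction now runs after it commits. A hedged reconstruction of the resulting call order:

    # paragraph, instance = transaction.atomic(edit)(instance)  # transaction commits on return
    # post_embedding(paragraph, instance)                       # then notify the vector store
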
@@ -68,17 +68,12 @@ class BaseVectorStore(ABC):
         :param trample_num  number of downvotes
         :return: bool
         """
-        # acquire the lock
-        lock.acquire()
-        try:
-            if embedding is None:
-                embedding = EmbeddingModel.get_embedding_model()
-            self.save_pre_handler()
-            self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
-                       trample_num, embedding)
-        finally:
-            # release the lock
-            lock.release()
+
+        if embedding is None:
+            embedding = EmbeddingModel.get_embedding_model()
+        self.save_pre_handler()
+        self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, star_num,
+                   trample_num, embedding)
 
     def batch_save(self, data_list: List[Dict], embedding=None):
         # acquire the lock

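The removed `lock.acquire()` / `try` / `finally: lock.release()` sequence is the manual spelling of Python's `with lock:` idiom; the commit drops the lock from single-item `save` entirely (while `batch_save` below still takes it), presumably so concurrent single-paragraph saves no longer serialize behind one process-wide lock. For reference, the context-manager form equivalent to what was removed:

    import threading

    lock = threading.Lock()

    def locked_save():
        # equivalent to the removed acquire/try/finally/release
        with lock:
            pass  # ... guarded work ...
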
@@ -16,35 +16,35 @@ prefix_dir = "/opt/maxkb/model"
 model_config = [
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-medium'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-large'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'gpt2-xl'
         },
         'download_function': GPT2TokenizerFast.from_pretrained
     },
     {
         'download_params': {
-            'cache_dir': os.path.join(prefix_dir, 'base'),
+            'cache_dir': os.path.join(prefix_dir, 'base/hub'),
             'pretrained_model_name_or_path': 'distilgpt2'
         },
         'download_function': GPT2TokenizerFast.from_pretrained

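Each `model_config` entry pairs a `download_params` dict with a `download_function`, so the table is presumably consumed by a generic loop outside this hunk; the `base` -> `base/hub` change moves the pre-downloaded tokenizer cache into a `hub/` subdirectory, evidently to match where the runtime later looks for the cache. A minimal sketch of such a driver, assuming nothing beyond what the hunk shows:

    # hypothetical driver loop; the real runner is not part of this diff
    for item in model_config:
        item['download_function'](**item['download_params'])
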