feat: 优化对话逻辑

This commit is contained in:
shaohuzhang1 2024-01-16 16:46:54 +08:00
parent 7349f00c54
commit 3f87335c80
46 changed files with 1393 additions and 396 deletions

View File

@ -0,0 +1,55 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file I_base_chat_pipeline.py
@date2024/1/9 17:25
@desc:
"""
import time
from abc import abstractmethod
from typing import Type
from rest_framework import serializers
class IBaseChatPipelineStep:
def __init__(self):
# 当前步骤上下文,用于存储当前步骤信息
self.context = {}
@abstractmethod
def get_step_serializer(self, manage) -> Type[serializers.Serializer]:
pass
def valid_args(self, manage):
step_serializer_clazz = self.get_step_serializer(manage)
step_serializer = step_serializer_clazz(data=manage.context)
step_serializer.is_valid(raise_exception=True)
self.context['step_args'] = step_serializer.data
def run(self, manage):
"""
:param manage: 步骤管理器
:return: 执行结果
"""
start_time = time.time()
# 校验参数,
self.valid_args(manage)
self._run(manage)
self.context['start_time'] = start_time
self.context['run_time'] = time.time() - start_time
def _run(self, manage):
pass
def execute(self, **kwargs):
pass
def get_details(self, manage, **kwargs):
"""
运行详情
:return: 步骤详情
"""
return None

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 17:23
@desc:
"""

View File

@ -0,0 +1,45 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file pipeline_manage.py
@date2024/1/9 17:40
@desc:
"""
import time
from functools import reduce
from typing import List, Type, Dict
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
class PiplineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
# 步骤执行器
self.step_list = [step() for step in step_list]
# 上下文
self.context = {'message_tokens': 0, 'answer_tokens': 0}
def run(self, context: Dict = None):
self.context['start_time'] = time.time()
if context is not None:
for key, value in context.items():
self.context[key] = value
for step in self.step_list:
step.run(self)
def get_details(self):
return reduce(lambda x, y: {**x, **y}, [{item.get('step_type'): item} for item in
filter(lambda r: r is not None,
[row.get_details(self) for row in self.step_list])], {})
class builder:
def __init__(self):
self.step_list: List[Type[IBaseChatPipelineStep]] = []
def append_step(self, step: Type[IBaseChatPipelineStep]):
self.step_list.append(step)
return self
def build(self):
return PiplineManage(step_list=self.step_list)

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,88 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_chat_step.py
@date2024/1/9 18:17
@desc: 对话
"""
from abc import abstractmethod
from typing import Type, List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from common.field.common import InstanceField
from dataset.models import Paragraph
class ModelField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseChatModel):
self.fail('模型类型错误', value=data)
return data
def to_representation(self, value):
return value
class MessageField(serializers.Field):
def to_internal_value(self, data):
if not isinstance(data, BaseMessage):
self.fail('message类型错误', value=data)
return data
def to_representation(self, value):
return value
class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[Paragraph], problem_text: str, answer_text,
manage, step, padding_problem_text: str = None, **kwargs):
pass
class IChatStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 对话列表
message_list = serializers.ListField(required=True, child=MessageField(required=True))
# 大语言模型
chat_model = ModelField()
# 段落列表
paragraph_list = serializers.ListField()
# 对话id
chat_id = serializers.UUIDField(required=True)
# 用户问题
problem_text = serializers.CharField(required=True)
# 后置处理器
post_response_handler = InstanceField(model_type=PostResponseHandler)
# 补全问题
padding_problem_text = serializers.CharField(required=False)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
message_list: List = self.initial_data.get('message_list')
for message in message_list:
if not isinstance(message, BaseMessage):
raise Exception("message 类型错误")
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
chat_result = self.execute(**self.context['step_args'], manage=manage)
manage.context['chat_result'] = chat_result
@abstractmethod
def execute(self, message_list: List[BaseMessage],
chat_id, problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None, **kwargs):
pass

View File

@ -0,0 +1,111 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_chat_step.py
@date2024/1/9 18:25
@desc: 对话step Base实现
"""
import json
import logging
import time
import traceback
import uuid
from typing import List
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from dataset.models import Paragraph
def event_content(response,
chat_id,
chat_record_id,
paragraph_list: List[Paragraph],
post_response_handler: PostResponseHandler,
manage,
step,
chat_model,
message_list: List[BaseMessage],
problem_text: str,
padding_problem_text: str = None):
all_text = ''
try:
for chunk in response:
all_text += chunk.content
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n"
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
# 获取token
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
step.context['message_tokens'] = request_token
step.context['answer_tokens'] = response_token
current_time = time.time()
step.context['answer_text'] = all_text
step.context['run_time'] = current_time - step.context['start_time']
manage.context['run_time'] = current_time - manage.context['start_time']
manage.context['message_tokens'] = manage.context['message_tokens'] + request_token
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '异常' + str(e), 'is_end': True}) + "\n\n"
class BaseChatStep(IChatStep):
def execute(self, message_list: List[BaseMessage],
chat_id,
problem_text,
post_response_handler: PostResponseHandler,
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PiplineManage = None,
padding_problem_text: str = None,
**kwargs):
# 调用模型
if chat_model is None:
chat_result = iter(
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
chat_result = chat_model.stream(message_list)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
def get_details(self, manage, **kwargs):
return {
'step_type': 'chat_step',
'run_time': self.context['run_time'],
'model_id': str(manage.context['model_id']),
'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
self.context['answer_text']),
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens'],
'cost': 0,
}
@staticmethod
def reset_message_list(message_list: List[BaseMessage], answer_text):
result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content} for
message
in
message_list]
result.append({'role': 'ai', 'content': answer_text})
return result

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,68 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_generate_human_message_step.py
@date2024/1/9 18:15
@desc: 生成对话模板
"""
from abc import abstractmethod
from typing import Type, List
from langchain.schema import BaseMessage
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.models import ChatRecord
from common.field.common import InstanceField
from dataset.models import Paragraph
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题
problem_text = serializers.CharField(required=True)
# 段落列表
paragraph_list = serializers.ListField(child=InstanceField(model_type=Paragraph, required=True))
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
# 多轮对话数量
dialogue_number = serializers.IntegerField(required=True)
# 最大携带知识库段落长度
max_paragraph_char_number = serializers.IntegerField(required=True)
# 模板
prompt = serializers.CharField(required=True)
# 补齐问题
padding_problem_text = serializers.CharField(required=False)
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
message_list = self.execute(**self.context['step_args'])
manage.context['message_list'] = message_list
@abstractmethod
def execute(self,
problem_text: str,
paragraph_list: List[Paragraph],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
**kwargs) -> List[BaseMessage]:
"""
:param problem_text: 原始问题文本
:param paragraph_list: 段落列表
:param history_chat_record: 历史对话记录
:param dialogue_number: 多轮对话数量
:param max_paragraph_char_number: 最大段落长度
:param prompt: 模板
:param padding_problem_text 用户修改文本
:param kwargs: 其他参数
:return:
"""
pass

View File

@ -0,0 +1,57 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_generate_human_message_step.py.py
@date2024/1/10 17:50
@desc:
"""
from typing import List
from langchain.schema import BaseMessage, HumanMessage
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
from common.util.split_model import flat_map
from dataset.models import Paragraph
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
def execute(self, problem_text: str,
paragraph_list: List[Paragraph],
history_chat_record: List[ChatRecord],
dialogue_number: int,
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
**kwargs) -> List[BaseMessage]:
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
return [*flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
@staticmethod
def to_human_message(prompt: str,
problem: str,
max_paragraph_char_number: int,
paragraph_list: List[Paragraph]):
if paragraph_list is None or len(paragraph_list) == 0:
return HumanMessage(content=problem)
temp_data = ""
data_list = []
for p in paragraph_list:
content = f"{p.title}:{p.content}"
temp_data += content
if len(temp_data) > max_paragraph_char_number:
row_data = content[0:max_paragraph_char_number - len(temp_data)]
data_list.append(f"<data>{row_data}</data>")
break
else:
data_list.append(f"<data>{content}</data>")
data = "\n".join(data_list)
return HumanMessage(content=prompt.format(**{'data': data, 'question': problem}))

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:23
@desc:
"""

View File

@ -0,0 +1,49 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_reset_problem_step.py
@date2024/1/9 18:12
@desc: 重写处理问题
"""
from abc import abstractmethod
from typing import Type, List
from langchain.chat_models.base import BaseChatModel
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import ModelField
from application.models import ChatRecord
from common.field.common import InstanceField
class IResetProblemStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 问题文本
problem_text = serializers.CharField(required=True)
# 历史对答
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True))
# 大语言模型
chat_model = ModelField()
def get_step_serializer(self, manage: PiplineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
padding_problem = self.execute(**self.context.get('step_args'))
# 用户输入问题
source_problem_text = self.context.get('step_args').get('problem_text')
self.context['problem_text'] = source_problem_text
self.context['padding_problem_text'] = padding_problem
manage.context['problem_text'] = source_problem_text
manage.context['padding_problem_text'] = padding_problem
# 累加tokens
manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens')
manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens')
@abstractmethod
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs):
pass

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_reset_problem_step.py
@date2024/1/10 14:35
@desc:
"""
from typing import List
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep
from application.models import ChatRecord
from common.util.split_model import flat_map
prompt = (
'()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中')
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs) -> str:
start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
message_list = [*flat_map(history_message),
HumanMessage(content=prompt.format(**{'question': problem_text}))]
response = chat_model(message_list)
padding_problem = response.content[response.content.index('<data>') + 6:response.content.index('</data>')]
self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
return padding_problem
def get_details(self, manage, **kwargs):
return {
'step_type': 'problem_padding',
'run_time': self.context['run_time'],
'model_id': str(manage.context['model_id']) if 'model_id' in manage.context else None,
'message_tokens': self.context['message_tokens'],
'answer_tokens': self.context['answer_tokens'],
'cost': 0,
'padding_problem_text': self.context.get('padding_problem_text'),
'problem_text': self.context.get("step_args").get('problem_text'),
}

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file __init__.py.py
@date2024/1/9 18:24
@desc:
"""

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file i_search_dataset_step.py
@date2024/1/9 18:10
@desc: 检索知识库
"""
from abc import abstractmethod
from typing import List, Type
from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from application.chat_pipeline.pipeline_manage import PiplineManage
from dataset.models import Paragraph
class ISearchDatasetStep(IBaseChatPipelineStep):
class InstanceSerializer(serializers.Serializer):
# 原始问题文本
problem_text = serializers.CharField(required=True)
# 系统补全问题文本
padding_problem_text = serializers.CharField(required=False)
# 需要查询的数据集id列表
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
# 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
# 需要排除向量id
exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True))
# 需要查询的条数
top_n = serializers.IntegerField(required=True)
# 相似度 0-1之间
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
def get_step_serializer(self, manage: PiplineManage) -> Type[InstanceSerializer]:
return self.InstanceSerializer
def _run(self, manage: PiplineManage):
paragraph_list = self.execute(**self.context['step_args'])
manage.context['paragraph_list'] = paragraph_list
@abstractmethod
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
**kwargs) -> List[Paragraph]:
"""
关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询
:param similarity: 相关性
:param top_n: 查询多少条
:param problem_text: 用户问题
:param dataset_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
:return: 段落列表
"""
pass

View File

@ -0,0 +1,58 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file base_search_dataset_step.py
@date2024/1/10 10:33
@desc:
"""
from typing import List
from django.db.models import QuerySet
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from dataset.models import Paragraph
class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
**kwargs) -> List[Paragraph]:
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(embedding_value, dataset_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity)
if embedding_list is None:
return []
return self.list_paragraph([row.get('paragraph_id') for row in embedding_list], vector)
@staticmethod
def list_paragraph(paragraph_id_list: List, vector):
if paragraph_id_list is None or len(paragraph_id_list) == 0:
return []
paragraph_list = QuerySet(Paragraph).filter(id__in=paragraph_id_list)
# 如果向量库中存在脏数据 直接删除
if len(paragraph_list) != len(paragraph_id_list):
exist_paragraph_list = [str(row.id) for row in paragraph_list]
for paragraph_id in paragraph_id_list:
if not exist_paragraph_list.__contains__(paragraph_id):
vector.delete_by_paragraph_id(paragraph_id)
return paragraph_list
def get_details(self, manage, **kwargs):
step_args = self.context['step_args']
return {
'step_type': 'search_step',
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': EmbeddingModel.get_embedding_model().model_name,
'message_tokens': 0,
'answer_tokens': 0,
'cost': 0
}

View File

@ -0,0 +1,55 @@
# Generated by Django 4.1.10 on 2024-01-12 18:46
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0002_alter_chatrecord_dataset'),
]
operations = [
migrations.RemoveField(
model_name='chatrecord',
name='dataset',
),
migrations.RemoveField(
model_name='chatrecord',
name='paragraph',
),
migrations.RemoveField(
model_name='chatrecord',
name='source_id',
),
migrations.RemoveField(
model_name='chatrecord',
name='source_type',
),
migrations.AddField(
model_name='chatrecord',
name='const',
field=models.IntegerField(default=0, verbose_name='总费用'),
),
migrations.AddField(
model_name='chatrecord',
name='details',
field=models.JSONField(default=list, verbose_name='对话详情'),
),
migrations.AddField(
model_name='chatrecord',
name='paragraph_id_list',
field=django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='引用段落id列表'),
),
migrations.AddField(
model_name='chatrecord',
name='run_time',
field=models.FloatField(default=0, verbose_name='运行时长'),
),
migrations.AlterField(
model_name='chatrecord',
name='answer_text',
field=models.CharField(max_length=4096, verbose_name='答案'),
),
]

View File

@ -0,0 +1,38 @@
# Generated by Django 4.1.10 on 2024-01-15 16:07
import application.models.application
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0003_remove_chatrecord_dataset_and_more'),
]
operations = [
migrations.RemoveField(
model_name='application',
name='example',
),
migrations.AddField(
model_name='application',
name='dataset_setting',
field=models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置'),
),
migrations.AddField(
model_name='application',
name='model_setting',
field=models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置'),
),
migrations.AddField(
model_name='application',
name='problem_optimization',
field=models.BooleanField(default=False, verbose_name='问题优化'),
),
migrations.AlterField(
model_name='chatrecord',
name='details',
field=models.JSONField(default={}, verbose_name='对话详情'),
),
]

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0004_remove_application_example_and_more'),
]
operations = [
migrations.AlterField(
model_name='chatrecord',
name='details',
field=models.JSONField(default=dict, verbose_name='对话详情'),
),
]

View File

@ -10,23 +10,47 @@ import uuid
from django.contrib.postgres.fields import ArrayField
from django.db import models
from langchain.schema import HumanMessage, AIMessage
from common.mixins.app_model_mixin import AppModelMixin
from dataset.models.data_set import DataSet, Paragraph
from embedding.models import SourceType
from dataset.models.data_set import DataSet
from setting.models.model_management import Model
from users.models import User
def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000}
def get_model_setting_dict():
return {'prompt': Application.get_default_model_prompt()}
class Application(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
name = models.CharField(max_length=128, verbose_name="应用名称")
desc = models.CharField(max_length=128, verbose_name="引用描述", default="")
prologue = models.CharField(max_length=1024, verbose_name="开场白", default="")
example = ArrayField(verbose_name="示例列表", base_field=models.CharField(max_length=256, blank=True), default=list)
dialogue_number = models.IntegerField(default=0, verbose_name="会话数量")
user = models.ForeignKey(User, on_delete=models.DO_NOTHING)
model = models.ForeignKey(Model, on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True)
dataset_setting = models.JSONField(verbose_name="数据集参数设置", default=get_dataset_setting_dict)
model_setting = models.JSONField(verbose_name="模型参数相关设置", default=get_model_setting_dict)
problem_optimization = models.BooleanField(verbose_name="问题优化", default=False)
@staticmethod
def get_default_model_prompt():
return ('已知信息:'
'\n{data}'
'\n回答要求:'
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
'\n- 避免提及你是从<data></data>中获得的知识。'
'\n- 请保持答案与<data></data>中描述的一致。'
'\n- 请使用markdown 语法优化答案的格式。'
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
'\n- 请使用与问题相同的语言来回答。'
'\n问题:'
'\n{question}')
class Meta:
db_table = "application"
@ -65,20 +89,28 @@ class ChatRecord(AppModelMixin):
chat = models.ForeignKey(Chat, on_delete=models.CASCADE)
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE)
dataset = models.ForeignKey(DataSet, on_delete=models.SET_NULL, verbose_name="数据集", blank=True, null=True)
paragraph = models.ForeignKey(Paragraph, on_delete=models.SET_NULL, verbose_name="段落id", blank=True, null=True)
source_id = models.UUIDField(max_length=128, verbose_name="资源id 段落/问题 id ", null=True)
source_type = models.CharField(verbose_name='资源类型', max_length=2, choices=SourceType.choices,
default=SourceType.PROBLEM, blank=True, null=True)
paragraph_id_list = ArrayField(verbose_name="引用段落id列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=4096, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=1024, verbose_name="答案")
const = models.IntegerField(verbose_name="总费用", default=0)
details = models.JSONField(verbose_name="对话详情", default=dict)
improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表",
base_field=models.UUIDField(max_length=128, blank=True)
, default=list)
run_time = models.FloatField(verbose_name="运行时长", default=0)
index = models.IntegerField(verbose_name="对话下标")
def get_human_message(self):
if 'problem_padding' in self.details:
return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text'))
return HumanMessage(content=self.problem_text)
def get_ai_message(self):
return AIMessage(content=self.answer_text)
class Meta:
db_table = "application_chat_record"

View File

@ -63,15 +63,35 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
fields = "__all__"
class DatasetSettingSerializer(serializers.Serializer):
top_n = serializers.FloatField(required=True)
similarity = serializers.FloatField(required=True, max_value=1, min_value=0)
max_paragraph_char_number = serializers.IntegerField(required=True, max_value=10000)
class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=4096)
class ApplicationSerializer(serializers.Serializer):
name = serializers.CharField(required=True)
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True)
model_id = serializers.CharField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True), allow_null=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
allow_null=True)
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
class AccessTokenSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
@ -135,13 +155,13 @@ class ApplicationSerializer(serializers.Serializer):
model_id = serializers.CharField(required=False)
multiple_rounds_dialogue = serializers.BooleanField(required=False)
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True)
example = serializers.ListSerializer(required=False, child=serializers.CharField(required=True))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
def is_valid(self, *, user_id=None, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
# 数据集相关设置
dataset_setting = serializers.JSONField(required=False, allow_null=True)
# 模型相关设置
model_setting = serializers.JSONField(required=False, allow_null=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=False, allow_null=True)
class Create(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
@ -168,9 +188,12 @@ class ApplicationSerializer(serializers.Serializer):
@staticmethod
def to_application_model(user_id: str, application: Dict):
return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'),
prologue=application.get('prologue'), example=application.get('example'),
prologue=application.get('prologue'),
dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0,
user_id=user_id, model_id=application.get('model_id'),
dataset_setting=application.get('dataset_setting'),
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization')
)
@staticmethod
@ -267,7 +290,7 @@ class ApplicationSerializer(serializers.Serializer):
class ApplicationModel(serializers.ModelSerializer):
class Meta:
model = Application
fields = ['id', 'name', 'desc', 'prologue', 'example', 'dialogue_number']
fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number']
class Operate(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
@ -317,8 +340,9 @@ class ApplicationSerializer(serializers.Serializer):
model = QuerySet(Model).get(id=instance.get('model_id') if 'model_id' in instance else application.model_id)
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'example', 'status',
'api_key_is_active']
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization'
'api_key_is_active']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':

View File

@ -7,194 +7,181 @@
@desc:
"""
import json
import uuid
from typing import List
from uuid import UUID
from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from rest_framework import serializers, status
from django.core.cache import cache
from common import event
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.response import result
from dataset.models import Paragraph
from embedding.models import SourceType
from setting.models.model_management import Model
from django.db.models import QuerySet
from langchain.chat_models.base import BaseChatModel
from rest_framework import serializers
from application.chat_pipeline.pipeline_manage import PiplineManage
from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
BaseGenerateHumanMessageStep
from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping
from common.exception.app_exception import AppApiException
from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = cache
class MessageManagement:
@staticmethod
def get_message(title: str, content: str, message: str):
if content is None:
return HumanMessage(content=message)
return HumanMessage(content=(
f'已知信息:{title}:{content} '
'根据上述已知信息,请简洁和专业的来回答用户的问题。已知信息中的图片、链接地址和脚本语言请直接返回。如果无法从已知信息中得到答案,请说 “没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作” 或 “根据已知信息无法回答该问题,建议联系官方技术支持人员”,不允许在答案中添加编造成分,答案请使用中文。'
f'问题是:{message}'))
class ChatMessage:
def __init__(self, id: str, problem: str, title: str, paragraph: str, embedding_id: str, dataset_id: str,
document_id: str,
paragraph_id,
source_type: SourceType,
source_id: str,
answer: str,
message_tokens: int,
answer_token: int,
chat_model=None,
chat_message=None):
self.id = id
self.problem = problem
self.title = title
self.paragraph = paragraph
self.embedding_id = embedding_id
self.dataset_id = dataset_id
self.document_id = document_id
self.paragraph_id = paragraph_id
self.source_type = source_type
self.source_id = source_id
self.answer = answer
self.message_tokens = message_tokens
self.answer_token = answer_token
self.chat_model = chat_model
self.chat_message = chat_message
def get_chat_message(self):
return MessageManagement.get_message(self.problem, self.paragraph, self.problem)
class ChatInfo:
def __init__(self,
chat_id: str,
model: Model,
chat_model: BaseChatModel,
application_id: str | None,
dataset_id_list: List[str],
exclude_document_id_list: list[str],
dialogue_number: int):
application: Application):
"""
:param chat_id: 对话id
:param chat_model: 对话模型
:param dataset_id_list: 数据集列表
:param exclude_document_id_list: 排除的文档
:param application: 应用信息
"""
self.chat_id = chat_id
self.application_id = application_id
self.model = model
self.application = application
self.chat_model = chat_model
self.dataset_id_list = dataset_id_list
self.exclude_document_id_list = exclude_document_id_list
self.dialogue_number = dialogue_number
self.chat_message_list: List[ChatMessage] = []
self.chat_record_list: List[ChatRecord] = []
def append_chat_message(self, chat_message: ChatMessage):
self.chat_message_list.append(chat_message)
if self.application_id is not None:
def to_base_pipeline_manage_params(self):
dataset_setting = self.application.dataset_setting
model_setting = self.application.model_setting
return {
'dataset_id_list': self.dataset_id_list,
'exclude_document_id_list': self.exclude_document_id_list,
'exclude_paragraph_id_list': [],
'top_n': dataset_setting.get('top_n') if 'top_n' in dataset_setting else 3,
'similarity': dataset_setting.get('similarity') if 'similarity' in dataset_setting else 0.6,
'max_paragraph_char_number': dataset_setting.get(
'max_paragraph_char_number') if 'max_paragraph_char_number' in dataset_setting else 5000,
'history_chat_record': self.chat_record_list,
'chat_id': self.chat_id,
'dialogue_number': self.application.dialogue_number,
'prompt': model_setting.get(
'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(),
'chat_model': self.chat_model,
'model_id': self.application.model.id if self.application.model is not None else None,
'problem_optimization': self.application.problem_optimization
}
def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
exclude_paragraph_id_list):
params = self.to_base_pipeline_manage_params()
return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
'exclude_paragraph_id_list': exclude_paragraph_id_list}
def append_chat_record(self, chat_record: ChatRecord):
# 存入缓存中
self.chat_record_list.append(chat_record)
if self.application.id is not None:
# 插入数据库
event.ListenerChatMessage.record_chat_message_signal.send(
event.RecordChatMessageArgs(len(self.chat_message_list) - 1, self.chat_id, self.application_id,
chat_message)
)
# 异步更新token
event.ListenerChatMessage.update_chat_message_token_signal.send(chat_message)
if not QuerySet(Chat).filter(id=self.chat_id).exists():
Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text).save()
# 插入会话记录
chat_record.save()
def get_context_message(self):
start_index = len(self.chat_message_list) - self.dialogue_number
return [self.chat_message_list[index].get_chat_message() for index in
range(start_index if start_index > 0 else 0, len(self.chat_message_list))]
def get_post_handler(chat_info: ChatInfo):
class PostHandler(PostResponseHandler):
def handler(self,
chat_id: UUID,
chat_record_id,
paragraph_list: List[Paragraph],
problem_text: str,
answer_text,
manage: PiplineManage,
step: BaseChatStep,
padding_problem_text: str = None,
**kwargs):
chat_record = ChatRecord(id=chat_record_id,
chat_id=chat_id,
paragraph_id_list=[str(p.id) for p in paragraph_list],
problem_text=problem_text,
answer_text=answer_text,
details=manage.get_details(),
message_tokens=manage.context['message_tokens'],
answer_tokens=manage.context['answer_tokens'],
run_time=manage.context['run_time'],
index=len(chat_info.chat_record_list) + 1)
chat_info.append_chat_record(chat_record)
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
return PostHandler()
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
def chat(self, message):
def chat(self, message, re_chat: bool):
self.is_valid(raise_exception=True)
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
if chat_info is None:
return result.Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="会话过期")
chat_info = self.re_open_chat(chat_id)
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
chat_model = chat_info.chat_model
vector = VectorStore.get_embedding_vector()
# 向量库检索
_value = vector.search(message, chat_info.dataset_id_list, chat_info.exclude_document_id_list,
[chat_message.embedding_id for chat_message in
(list(filter(lambda row: row.problem == message, chat_info.chat_message_list)))],
True,
EmbeddingModel.get_embedding_model())
# 查询段落id详情
paragraph = None
if _value is not None:
paragraph = QuerySet(Paragraph).get(id=_value.get('paragraph_id'))
if paragraph is None:
vector.delete_by_paragraph_id(_value.get('paragraph_id'))
pipline_manage_builder = PiplineManage.builder()
# 如果开启了问题优化,则添加上问题优化步骤
if chat_info.application.problem_optimization:
pipline_manage_builder.append_step(BaseResetProblemStep)
# 构建流水线管理器
pipline_message = (pipline_manage_builder.append_step(BaseSearchDatasetStep)
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.build())
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
if re_chat:
paragraph_id_list = flat_map([row.paragraph_id_list for row in
filter(lambda chat_record: chat_record == message,
chat_info.chat_record_list)])
exclude_paragraph_id_list = list(set(paragraph_id_list))
# 构建运行参数
params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list)
# 运行流水线作业
pipline_message.run(params)
return pipline_message.context['chat_result']
title, content = (None, None) if paragraph is None else (paragraph.title, paragraph.content)
_id = str(uuid.uuid1())
@staticmethod
def re_open_chat(chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()
if chat is None:
raise AppApiException(500, "会话不存在")
application = QuerySet(Application).filter(id=chat.application_id).first()
if application is None:
raise AppApiException(500, "应用不存在")
model = QuerySet(Model).filter(id=application.model_id).first()
chat_model = None
if model is not None:
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
# 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
application_id=application.id)]
embedding_id, dataset_id, document_id, paragraph_id, source_type, source_id = (_value.get(
'id'), _value.get(
'dataset_id'), _value.get(
'document_id'), _value.get(
'paragraph_id'), _value.get(
'source_type'), _value.get(
'source_id')) if _value is not None else (None, None, None, None, None, None)
if chat_model is None:
def event_block_content(c: str):
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'is_end': True,
'content': c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~'}) + "\n\n"
chat_info.append_chat_message(
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id,
c if c is not None else '抱歉,根据已知信息无法回答这个问题,请重新描述您的问题或提供更多信息~',
0,
0))
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
r = StreamingHttpResponse(streaming_content=event_block_content(content),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
# 获取上下文
history_message = chat_info.get_context_message()
# 构建会话请求问题
chat_message = [*history_message, MessageManagement.get_message(title, content, message)]
# 对话
result_data = chat_model.stream(chat_message)
def event_content(response):
all_text = ''
try:
for chunk in response:
all_text += chunk.content
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'content': chunk.content, 'is_end': False}) + "\n\n"
yield 'data: ' + json.dumps({'chat_id': chat_id, 'id': _id, 'operate': paragraph is not None,
'content': '', 'is_end': True}) + "\n\n"
chat_info.append_chat_message(
ChatMessage(_id, message, title, content, embedding_id, dataset_id, document_id,
paragraph_id,
source_type,
source_id, all_text,
0,
0,
chat_message=chat_message, chat_model=chat_model))
# 重新设置缓存
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
except Exception as e:
yield e
r = StreamingHttpResponse(streaming_content=event_content(result_data),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
return r
# 需要排除的文档
exclude_document_id_list = [str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
return ChatInfo(chat_id, chat_model, dataset_id_list, exclude_document_id_list, application)

View File

@ -10,7 +10,8 @@ import datetime
import json
import os
import uuid
from typing import Dict
from functools import reduce
from typing import Dict, List
from django.core.cache import cache
from django.db import transaction
@ -18,7 +19,8 @@ from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.serializers.application_serializers import ModelDatasetAssociation
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
from common.db.search import native_search, native_page_search, page_search
from common.event import ListenerManagement
@ -26,8 +28,8 @@ from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
from dataset.models import Document, Problem, Paragraph
from embedding.models import SourceType, Embedding
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
@ -106,12 +108,12 @@ class ChatSerializers(serializers.Serializer):
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
application.dialogue_number), timeout=60 * 30)
application), timeout=60 * 30)
return chat_id
class OpenTempChat(serializers.Serializer):
@ -122,6 +124,12 @@ class ChatSerializers(serializers.Serializer):
multiple_rounds_dialogue = serializers.BooleanField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
# 数据集相关设置
dataset_setting = DatasetSettingSerializer(required=True)
# 模型相关设置
model_setting = ModelSettingSerializer(required=True)
# 问题补全
problem_optimization = serializers.BooleanField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@ -140,42 +148,62 @@ class ChatSerializers(serializers.Serializer):
json.loads(
decrypt(model.credential)),
streaming=True)
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),
problem_optimization=self.data.get('problem_optimization'))
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
ChatInfo(chat_id, chat_model, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
application), timeout=60 * 30)
return chat_id
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
if source_type == SourceType.PROBLEM:
problem = QuerySet(Problem).get(id=source_id)
if problem is not None:
problem.__setattr__(field, post_handler(problem))
problem.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, problem.__getattribute__(field))
embedding.save()
if source_type == SourceType.PARAGRAPH:
paragraph = QuerySet(Paragraph).get(id=source_id)
if paragraph is not None:
paragraph.__setattr__(field, post_handler(paragraph))
paragraph.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, paragraph.__getattribute__(field))
embedding.save()
class ChatRecordSerializerModel(serializers.ModelSerializer):
class Meta:
model = ChatRecord
fields = "__all__"
fields = ['id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
'message_tokens', 'answer_tokens', 'const', 'improve_paragraph_id_list', 'run_time', 'index']
class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
dataset_list = []
paragraph_list = []
if len(chat_record.paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=chat_record.paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list}
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True)
@ -183,15 +211,57 @@ class ChatRecordSerializer(serializers.Serializer):
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))
return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
def reset_chat_record_list(self, chat_record_list: List[ChatRecord]):
paragraph_id_list = flat_map([chat_record.paragraph_id_list for chat_record in chat_record_list])
# 去重
paragraph_id_list = list(set(paragraph_id_list))
paragraph_list = self.search_paragraph(paragraph_id_list)
return [self.reset_chat_record(chat_record, paragraph_list) for chat_record in chat_record_list]
@staticmethod
def search_paragraph(paragraph_id_list: List[str]):
paragraph_list = []
if len(paragraph_id_list) > 0:
paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql',
'list_dataset_paragraph_by_paragraph_id.sql')),
with_table_name=True)
return paragraph_list
@staticmethod
def reset_chat_record(chat_record, all_paragraph_list):
paragraph_list = list(
filter(lambda paragraph: chat_record.paragraph_id_list.__contains__(str(paragraph.get('id'))),
all_paragraph_list))
dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in reduce(lambda x, y: {**x, **y},
[{row.get(
'dataset_id'): row.get(
"dataset_name")} for
row in
paragraph_list],
{}).items()]
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list
}
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return page_search(current_page, page_size,
page = page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
post_records_handler=lambda chat_record: chat_record)
records = page.get('records')
page['records'] = self.reset_chat_record_list(records)
return page
class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
@ -216,38 +286,20 @@ class ChatRecordSerializer(serializers.Serializer):
if vote_status == VoteChoices.STAR:
# 点赞
chat_record_details_model.vote_status = VoteChoices.STAR
# 点赞数量 +1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num',
lambda r: r.star_num + 1)
if vote_status == VoteChoices.TRAMPLE:
# 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
# 点踩数量+1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num',
lambda r: r.trample_num + 1)
chat_record_details_model.save()
else:
if vote_status == VoteChoices.UN_VOTE:
# 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save()
if chat_record_details_model.vote_status == VoteChoices.STAR:
# 点赞数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num', lambda r: r.star_num - 1)
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
# 点踩数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num', lambda r: r.trample_num - 1)
else:
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
finally:
un_lock(self.data.get('chat_record_id'))
return True
class ImproveSerializer(serializers.Serializer):

View File

@ -1,4 +1,4 @@
SELECT * FROM ( SELECT * FROM application ${application_custom_sql} UNION
SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting FROM ( SELECT * FROM application ${application_custom_sql} UNION
SELECT
*
FROM

View File

@ -0,0 +1,6 @@
SELECT
paragraph.*,
dataset."name" AS "dataset_name"
FROM
paragraph paragraph
LEFT JOIN dataset dataset ON dataset."id" = paragraph.dataset_id

View File

@ -58,6 +58,7 @@ class ApplicationApi(ApiMixin):
'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表",
@ -133,7 +134,7 @@ class ApplicationApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
required=[],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@ -141,11 +142,52 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
}
)
class DatasetSetting(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=[''],
properties={
'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
default=5),
'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
default=0.6),
'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
description="最多引用字符数", default=3000),
}
)
class ModelSetting(ApiMixin):
@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['prompt'],
properties={
'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
default=('已知信息:'
'\n{data}'
'\n回答要求:'
'\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
'\n- 避免提及你是从<data></data>中获得的知识。'
'\n- 请保持答案与<data></data>中描述的一致。'
'\n- 请使用markdown 语法优化答案的格式。'
'\n- <data></data>中的图片链接、链接地址和脚本语言请完整返回。'
'\n- 请使用与问题相同的语言来回答。'
'\n问题:'
'\n{question}')),
}
)
@ -155,7 +197,8 @@ class ApplicationApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue'],
required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
'problem_optimization'],
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
@ -163,11 +206,13 @@ class ApplicationApi(ApiMixin):
"multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话",
description="是否开启多轮对话"),
'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
title="示例列表", description="示例列表"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表")
title="关联知识库Id列表", description="关联知识库Id列表"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
}
)

View File

@ -8,6 +8,7 @@
"""
from drf_yasg import openapi
from application.swagger_api.application_api import ApplicationApi
from common.mixins.api_mixin import ApiMixin
@ -19,6 +20,7 @@ class ChatApi(ApiMixin):
required=['message'],
properties={
'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default="重新生成")
}
)
@ -73,14 +75,19 @@ class ChatApi(ApiMixin):
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['model_id', 'multiple_rounds_dialogue'],
required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
'problem_optimization'],
properties={
'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title="关联知识库Id列表", description="关联知识库Id列表"),
'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
description="是否开启多轮会话")
description="是否开启多轮会话"),
'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
description="是否开启问题优化", default=True)
}
)

View File

@ -25,6 +25,8 @@ urlpatterns = [
path('application/<str:application_id>/chat/<chat_id>/chat_record/', views.ChatView.ChatRecord.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<int:current_page>/<int:page_size>',
views.ChatView.ChatRecord.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<chat_record_id>',
views.ChatView.ChatRecord.Operate.as_view()),
path('application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/vote',
views.ChatView.ChatRecord.Vote.as_view(),
name=''),

View File

@ -71,7 +71,8 @@ class ChatView(APIView):
dynamic_tag=keywords.get('application_id'))])
)
def post(self, request: Request, chat_id: str):
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'))
return ChatMessageSerializer(data={'chat_id': chat_id}).chat(request.data.get('message'), request.data.get(
're_chat') if 're_chat' in request.data else False)
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话列表",
@ -134,6 +135,27 @@ class ChatView(APIView):
class ChatRecord(APIView):
authentication_classes = [TokenAuth]
class Operate(APIView):
authentication_classes = [TokenAuth]
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录详情",
operation_id="获取对话记录详情",
manual_parameters=ChatRecordApi.get_request_params_api(),
responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()),
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)
def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str):
return result.success(ChatRecordSerializer.Operate(
data={'application_id': application_id,
'chat_id': chat_id,
'chat_record_id': chat_record_id}).one())
@action(methods=['GET'], detail=False)
@swagger_auto_schema(operation_summary="获取对话记录列表",
operation_id="获取对话记录列表",

View File

@ -83,7 +83,7 @@ def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, s
field_replace_dict = get_field_replace_dict(queryset)
app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection,
field_replace_dict=field_replace_dict)
sql, params = app_sql_compiler.get_query_str(with_table_name)
sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name)
return sql, params

View File

@ -7,10 +7,8 @@
@desc:
"""
from .listener_manage import *
from .listener_chat_message import *
def run():
listener_manage.ListenerManagement().run()
listener_chat_message.ListenerChatMessage().run()
QuerySet(Document).filter(status=Status.embedding).update(**{'status': Status.error})

View File

@ -1,67 +0,0 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file listener_manage.py
@date2023/10/20 14:01
@desc:
"""
import logging
from blinker import signal
from django.db.models import QuerySet
from application.models import ChatRecord, Chat
from application.serializers.chat_message_serializers import ChatMessage
from common.event.common import poxy
class RecordChatMessageArgs:
def __init__(self, index: int, chat_id: str, application_id: str, chat_message: ChatMessage):
self.index = index
self.chat_id = chat_id
self.application_id = application_id
self.chat_message = chat_message
class ListenerChatMessage:
record_chat_message_signal = signal("record_chat_message")
update_chat_message_token_signal = signal("update_chat_message_token")
@staticmethod
def record_chat_message(args: RecordChatMessageArgs):
if not QuerySet(Chat).filter(id=args.chat_id).exists():
Chat(id=args.chat_id, application_id=args.application_id, abstract=args.chat_message.problem).save()
# 插入会话记录
try:
chat_record = ChatRecord(
id=args.chat_message.id,
chat_id=args.chat_id,
dataset_id=args.chat_message.dataset_id,
paragraph_id=args.chat_message.paragraph_id,
source_id=args.chat_message.source_id,
source_type=args.chat_message.source_type,
problem_text=args.chat_message.problem,
answer_text=args.chat_message.answer,
index=args.index,
message_tokens=args.chat_message.message_tokens,
answer_tokens=args.chat_message.answer_token)
chat_record.save()
except Exception as e:
print(e)
@staticmethod
@poxy
def update_token(chat_message: ChatMessage):
if chat_message.chat_model is not None:
logging.getLogger("max_kb").info("开始更新token")
message_token = chat_message.chat_model.get_num_tokens_from_messages(chat_message.chat_message)
answer_token = chat_message.chat_model.get_num_tokens(chat_message.answer)
# 修改token数量
QuerySet(ChatRecord).filter(id=chat_message.id).update(
**{'message_tokens': message_token, 'answer_tokens': answer_token})
def run(self):
# 记录会话
ListenerChatMessage.record_chat_message_signal.connect(self.record_chat_message)
ListenerChatMessage.update_chat_message_token_signal.connect(self.update_token)

View File

@ -0,0 +1,34 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file common.py
@date2024/1/11 18:44
@desc:
"""
from rest_framework import serializers
class InstanceField(serializers.Field):
def __init__(self, model_type, **kwargs):
self.model_type = model_type
super().__init__(**kwargs)
def to_internal_value(self, data):
if not isinstance(data, self.model_type):
self.fail('message类型错误', value=data)
return data
def to_representation(self, value):
return value
class FunctionField(serializers.Field):
def to_internal_value(self, data):
if not callable(data):
self.fail('不是一個函數', value=data)
return data
def to_representation(self, value):
return value

View File

@ -5,9 +5,7 @@ SELECT
problem.dataset_id AS dataset_id,
0 AS source_type,
problem."content" AS "text",
paragraph.is_active AS is_active,
problem.star_num as star_num,
problem.trample_num as trample_num
paragraph.is_active AS is_active
FROM
problem problem
LEFT JOIN paragraph paragraph ON paragraph."id" = problem.paragraph_id
@ -23,9 +21,7 @@ SELECT
concat_ws('
',concat_ws('
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
paragraph.is_active AS is_active,
paragraph.star_num as star_num,
paragraph.trample_num as trample_num
paragraph.is_active AS is_active
FROM
paragraph paragraph

View File

@ -0,0 +1,37 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('dataset', '0003_alter_paragraph_content'),
]
operations = [
migrations.RemoveField(
model_name='paragraph',
name='hit_num',
),
migrations.RemoveField(
model_name='paragraph',
name='star_num',
),
migrations.RemoveField(
model_name='paragraph',
name='trample_num',
),
migrations.RemoveField(
model_name='problem',
name='hit_num',
),
migrations.RemoveField(
model_name='problem',
name='star_num',
),
migrations.RemoveField(
model_name='problem',
name='trample_num',
),
]

View File

@ -74,9 +74,6 @@ class Paragraph(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING)
content = models.CharField(max_length=4096, verbose_name="段落内容")
title = models.CharField(max_length=256, verbose_name="标题", default="")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices,
default=Status.embedding)
is_active = models.BooleanField(default=True)
@ -94,9 +91,6 @@ class Problem(AppModelMixin):
dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False)
paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False)
content = models.CharField(max_length=256, verbose_name="问题内容")
hit_num = models.IntegerField(verbose_name="命中数量", default=0)
star_num = models.IntegerField(verbose_name="点赞数", default=0)
trample_num = models.IntegerField(verbose_name="点踩数", default=0)
class Meta:
db_table = "problem"

View File

@ -26,12 +26,11 @@ from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event.listener_manage import ListenerManagement, SyncWebDatasetArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork, ForkManage
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type
from dataset.serializers.common_serializers import list_paragraph
@ -286,7 +285,8 @@ class DataSetSerializers(serializers.ModelSerializer):
properties={
'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"),
'desc': openapi.Schema(type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"),
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", description="web站点url"),
'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url",
description="web站点url"),
'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器")
}
)
@ -369,7 +369,8 @@ class DataSetSerializers(serializers.ModelSerializer):
dataset_id = uuid.uuid1()
dataset = DataSet(
**{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id,
'type': Type.web, 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
'type': Type.web,
'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector')}})
dataset.save()
ListenerManagement.sync_web_dataset_signal.send(
SyncWebDatasetArgs(str(dataset_id), instance.get('source_url'), instance.get('selector'),

View File

@ -28,7 +28,7 @@ from dataset.serializers.problem_serializers import ProblemInstanceSerializer, P
class ParagraphSerializer(serializers.ModelSerializer):
class Meta:
model = Paragraph
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'document_id', 'title',
fields = ['id', 'content', 'is_active', 'document_id', 'title',
'create_time', 'update_time']

View File

@ -24,7 +24,7 @@ from embedding.vector.pg_vector import PGVector
class ProblemSerializer(serializers.ModelSerializer):
class Meta:
model = Problem
fields = ['id', 'content', 'hit_num', 'star_num', 'trample_num', 'dataset_id', 'document_id',
fields = ['id', 'content', 'dataset_id', 'document_id',
'create_time', 'update_time']
@ -77,8 +77,8 @@ class ProblemSerializers(ApiMixin, serializers.Serializer):
'document_id': self.data.get('document_id'),
'paragraph_id': self.data.get('paragraph_id'),
'dataset_id': self.data.get('dataset_id'),
'star_num': 0,
'trample_num': 0})
})
return ProblemSerializers.Operate(
data={'dataset_id': self.data.get('dataset_id'), 'document_id': self.data.get('document_id'),

View File

@ -0,0 +1,37 @@
# Generated by Django 4.1.10 on 2024-01-16 11:22
import django.contrib.postgres.fields
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('embedding', '0001_initial'),
]
operations = [
migrations.RemoveField(
model_name='embedding',
name='star_num',
),
migrations.RemoveField(
model_name='embedding',
name='trample_num',
),
migrations.AddField(
model_name='embedding',
name='keywords',
field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(max_length=256), default=list, size=None, verbose_name='关键词列表'),
),
migrations.AddField(
model_name='embedding',
name='meta',
field=models.JSONField(default=dict, verbose_name='元数据'),
),
migrations.AlterField(
model_name='embedding',
name='source_type',
field=models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型'),
),
]

View File

@ -6,6 +6,7 @@
@date2023/9/21 15:46
@desc:
"""
from django.contrib.postgres.fields import ArrayField
from django.db import models
from common.field.vector_field import VectorField
@ -16,6 +17,7 @@ class SourceType(models.TextChoices):
"""订单类型"""
PROBLEM = 0, '问题'
PARAGRAPH = 1, '段落'
TITLE = 2, '标题'
class Embedding(models.Model):
@ -36,10 +38,10 @@ class Embedding(models.Model):
embedding = VectorField(verbose_name="向量")
star_num = models.IntegerField(default=0, verbose_name="点赞数量")
keywords = ArrayField(verbose_name="关键词列表",
base_field=models.CharField(max_length=256), default=list)
trample_num = models.IntegerField(default=0,
verbose_name="点踩数量")
meta = models.JSONField(verbose_name="元数据", default=dict)
class Meta:
db_table = "embedding"

View File

@ -1,15 +1,17 @@
SELECT * FROM (SELECT
*,
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
CASE
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
SELECT
paragraph_id,
comprehensive_score,
comprehensive_score as similarity
FROM
embedding,
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
${embedding_query}
) temp
WHERE similarity>0.5
ORDER BY (similarity + score) DESC LIMIT 1
(
SELECT DISTINCT ON
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
FROM
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP
ORDER BY
paragraph_id,
similarity DESC
) DISTINCT_TEMP
WHERE comprehensive_score>%s
ORDER BY comprehensive_score DESC
LIMIT %s

View File

@ -1,34 +1,17 @@
SELECT
similarity,
paragraph_id,
comprehensive_score
comprehensive_score,
comprehensive_score as similarity
FROM
(
SELECT DISTINCT ON
( "paragraph_id" ) ( similarity + score ),*,
( similarity + score ) AS comprehensive_score
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score
FROM
(
SELECT
*,
( 1 - ( embedding.embedding <=> %s ) ) AS similarity,
CASE
WHEN embedding.star_num - embedding.trample_num = 0 THEN
0 ELSE ( ( ( embedding.star_num - embedding.trample_num ) - aggs.min_value ) / ( aggs.max_value - aggs.min_value ) )
END AS score
FROM
embedding,
( SELECT MIN ( star_num - trample_num ) AS min_value, MAX ( star_num - trample_num ) AS max_value FROM embedding ${embedding_query}) aggs
${embedding_query}
) TEMP
WHERE
similarity > %s
( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP
ORDER BY
paragraph_id,
( similarity + score )
DESC
) ss
ORDER BY
comprehensive_score DESC
LIMIT %s
similarity DESC
) DISTINCT_TEMP
WHERE comprehensive_score>%s
ORDER BY comprehensive_score DESC
LIMIT %s

View File

@ -99,8 +99,6 @@ class BaseVectorStore(ABC):
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings):
pass
@ -108,11 +106,20 @@ class BaseVectorStore(ABC):
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
pass
@abstractmethod
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str],
exclude_paragraph_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
embedding_query = embedding.embed_query(query_text)
result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list,
is_active, 1, 0.65)
return result[0]
@abstractmethod
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
pass
@abstractmethod

View File

@ -33,8 +33,6 @@ class PGVector(BaseVectorStore):
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
star_num: int,
trample_num: int,
embedding: HuggingFaceEmbeddings):
text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(),
@ -44,10 +42,7 @@ class PGVector(BaseVectorStore):
paragraph_id=paragraph_id,
source_id=source_id,
embedding=text_embedding,
source_type=source_type,
star_num=star_num,
trample_num=trample_num
)
source_type=source_type)
embedding.save()
return True
@ -61,8 +56,6 @@ class PGVector(BaseVectorStore):
is_active=text_list[index].get('is_active', True),
source_id=text_list[index].get('source_id'),
source_type=text_list[index].get('source_type'),
star_num=text_list[index].get('star_num'),
trample_num=text_list[index].get('trample_num'),
embedding=embeddings[index]) for index in
range(0, len(text_list))]
QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
@ -78,29 +71,27 @@ class PGVector(BaseVectorStore):
'hit_test.sql')),
with_table_name=True)
embedding_model = select_list(exec_sql,
[json.dumps(embedding_query), *exec_params, *exec_params, similarity, top_number])
[json.dumps(embedding_query), *exec_params, similarity, top_number])
return embedding_model
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_id_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
def query(self, query_embedding: List[float], dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float):
exclude_dict = {}
if dataset_id_list is None or len(dataset_id_list) == 0:
return None
return []
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
embedding_query = embedding.embed_query(query_text)
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
if exclude_id_list is not None and len(exclude_id_list) > 0:
exclude_dict.__setitem__('id__in', exclude_id_list)
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
query_set = query_set.exclude(**exclude_dict)
exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "embedding", 'sql',
'embedding_search.sql')),
with_table_name=True)
embedding_model = select_one(exec_sql, (json.dumps(embedding_query), *exec_params, *exec_params))
embedding_model = select_list(exec_sql,
[json.dumps(query_embedding), *exec_params, similarity, top_n])
return embedding_model
def update_by_source_id(self, source_id: str, instance: Dict):

View File

@ -14,13 +14,23 @@ from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage
from langchain.schema.messages import BaseMessageChunk, BaseMessage, HumanMessage, AIMessage, get_buffer_string
from langchain.schema.output import ChatGenerationChunk
from langchain.schema.runnable import RunnableConfig
from transformers import GPT2TokenizerFast
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
force_download=False)
class QianfanChatModel(QianfanChatEndpoint):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
return len(tokenizer.encode(text))
def stream(
self,
input: LanguageModelInput,
@ -30,7 +40,7 @@ class QianfanChatModel(QianfanChatEndpoint):
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if len(input) % 2 == 0:
input = [HumanMessage(content='占位'), *input]
input = [HumanMessage(content='padding'), *input]
input = [
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
for index in range(0, len(input))]