* feat: Add no-references segment setting

shaohuzhang1 2024-04-24 15:03:58 +08:00 committed by GitHub
parent a3f47102a6
commit 1f522dc551
12 changed files with 218 additions and 48 deletions

View File

@@ -15,9 +15,9 @@ from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.serializers.application_serializers import NoReferencesSetting
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
class ModelField(serializers.Field):
@@ -70,6 +70,8 @@ class IChatStep(IBaseChatPipelineStep):
stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
# Setting for when no referenced segments are found
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
@@ -92,5 +94,6 @@ class IChatStep(IBaseChatPipelineStep):
chat_model: BaseChatModel = None,
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, **kwargs):
padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
no_references_setting=None, **kwargs):
pass
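Note: the new no_references_setting field is required, so every chat-step invocation must now carry it. Below is a minimal plain-Python sketch of the shape the NoReferencesSetting serializer accepts; the helper is hypothetical, for illustration only, since the real validation is done by DRF.

# Hypothetical stand-in for the NoReferencesSetting serializer's checks.
VALID_STATUSES = {'ai_questioning', 'designated_answer'}

def validate_no_references_setting(setting: dict) -> dict:
    # status: what to do when retrieval finds no referenced segments.
    if setting.get('status') not in VALID_STATUSES:
        raise ValueError("status must be 'ai_questioning' or 'designated_answer'")
    # value: prompt template (ai_questioning) or fixed answer text (designated_answer).
    if not setting.get('value'):
        raise ValueError('value is required')
    return setting

validate_no_references_setting({'status': 'ai_questioning', 'value': '{question}'})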

View File

@@ -17,7 +17,8 @@ from django.db.models import QuerySet
from django.http import StreamingHttpResponse
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import BaseMessageChunk, HumanMessage, AIMessage
from langchain.schema.messages import HumanMessage, AIMessage
from langchain_core.messages import AIMessageChunk
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
@@ -47,7 +48,8 @@ def event_content(response,
message_list: List[BaseMessage],
problem_text: str,
padding_problem_text: str = None,
client_id=None, client_type=None):
client_id=None, client_type=None,
is_ai_chat: bool = None):
all_text = ''
try:
for chunk in response:
@@ -56,8 +58,12 @@
'content': chunk.content, 'is_end': False}) + "\n\n"
# Count tokens
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
else:
request_token = 0
response_token = 0
step.context['message_tokens'] = request_token
step.context['answer_tokens'] = response_token
current_time = time.time()
@@ -88,15 +94,16 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None,
stream: bool = True,
client_id=None, client_type=None,
no_references_setting=None,
**kwargs):
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type)
manage, padding_problem_text, client_id, client_type, no_references_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, client_id, client_type)
manage, padding_problem_text, client_id, client_type, no_references_setting)
def get_details(self, manage, **kwargs):
return {
@@ -127,19 +134,26 @@ class BaseChatStep(IChatStep):
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None):
client_id=None, client_type=None,
no_references_setting=None):
is_ai_chat = False
# Invoke the model
if chat_model is None:
chat_result = iter(
[BaseMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
chat_result = chat_model.stream(message_list)
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
padding_problem_text, client_id, client_type),
padding_problem_text, client_id, client_type, is_ai_chat),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
@@ -153,16 +167,26 @@ class BaseChatStep(IChatStep):
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None):
client_id=None, client_type=None, no_references_setting=None):
is_ai_chat = False
# Invoke the model
if chat_model is None:
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
chat_result = chat_model.invoke(message_list)
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1()
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
else:
request_token = 0
response_token = 0
self.context['message_tokens'] = request_token
self.context['answer_tokens'] = response_token
current_time = time.time()
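Both execute_stream and execute_block now follow the same pattern: short-circuit to the designated answer when retrieval returned nothing, and count tokens only when a model actually ran. A condensed sketch of that control flow, assuming a LangChain-style chat model; the function name is illustrative, not part of the commit.

from langchain_core.messages import AIMessage

def resolve_answer(chat_model, message_list, paragraph_list, no_references_setting):
    # Returns (message, is_ai_chat); is_ai_chat=False reports token counts of 0.
    no_refs = paragraph_list is None or len(paragraph_list) == 0
    if chat_model is None:
        # No model configured: echo the retrieved segments verbatim.
        return AIMessage(content="\n\n".join(
            p.title + "\n" + p.content for p in paragraph_list or [])), False
    if no_refs and no_references_setting.get('status') == 'designated_answer':
        # Nothing retrieved: answer with the configured fixed text, skip the model.
        return AIMessage(content=no_references_setting.get('value')), False
    return chat_model.invoke(message_list), True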

View File

@@ -15,9 +15,9 @@ from rest_framework import serializers
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.models import ChatRecord
from application.serializers.application_serializers import NoReferencesSetting
from common.field.common import InstanceField
from common.util.field_message import ErrMessage
from dataset.models import Paragraph
class IGenerateHumanMessageStep(IBaseChatPipelineStep):
@@ -39,6 +39,8 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
# Padded question text
padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题"))
# Setting for when no referenced segments are found
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer
@@ -56,6 +58,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
**kwargs) -> List[BaseMessage]:
"""
@@ -67,6 +70,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep):
:param prompt: prompt template
:param padding_problem_text: user-adjusted question text
:param kwargs: other parameters
:param no_references_setting: no-references segment setting
:return:
"""
pass

View File

@@ -6,7 +6,7 @@
@date: 2024/1/10 17:50
@desc:
"""
from typing import List
from typing import List, Dict
from langchain.schema import BaseMessage, HumanMessage
@@ -26,22 +26,31 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
max_paragraph_char_number: int,
prompt: str,
padding_problem_text: str = None,
no_references_setting=None,
**kwargs) -> List[BaseMessage]:
prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get(
'value')
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
return [*flat_map(history_message),
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list)]
self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list,
no_references_setting)]
@staticmethod
def to_human_message(prompt: str,
problem: str,
max_paragraph_char_number: int,
paragraph_list: List[ParagraphPipelineModel]):
paragraph_list: List[ParagraphPipelineModel],
no_references_setting: Dict):
if paragraph_list is None or len(paragraph_list) == 0:
return HumanMessage(content=prompt.format(**{'data': "<data></data>", 'question': problem}))
if no_references_setting.get('status') == 'ai_questioning':
return HumanMessage(
content=no_references_setting.get('value').format(**{'question': problem}))
else:
return HumanMessage(content=prompt.format(**{'data': "", 'question': problem}))
temp_data = ""
data_list = []
for p in paragraph_list:

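The empty-paragraph branch of to_human_message now depends on the configured status. A standalone sketch of just that branch; build_empty_data_message is an illustrative name, not the actual method.

from langchain.schema import HumanMessage

def build_empty_data_message(prompt: str, problem: str, no_references_setting: dict):
    if no_references_setting.get('status') == 'ai_questioning':
        # Forward the question through the configured template,
        # e.g. the default '{question}' passes it on unchanged.
        return HumanMessage(
            content=no_references_setting.get('value').format(question=problem))
    # designated_answer: keep the normal prompt with empty data; the chat step
    # will short-circuit to the fixed answer before any model call.
    return HumanMessage(content=prompt.format(data="", question=problem))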
View File

@@ -19,7 +19,11 @@ from users.models import User
def get_dataset_setting_dict():
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding'}
return {'top_n': 3, 'similarity': 0.6, 'max_paragraph_char_number': 5000, 'search_mode': 'embedding',
'no_references_setting': {
'status': 'ai_questioning',
'value': '{question}'
}}
def get_model_setting_dict():

View File

@@ -73,6 +73,18 @@ class ApplicationSerializerModel(serializers.ModelSerializer):
fields = "__all__"
class NoReferencesChoices(models.TextChoices):
"""订单类型"""
ai_questioning = 'ai_questioning', 'ai回答'
designated_answer = 'designated_answer', '指定回答'
class NoReferencesSetting(serializers.Serializer):
status = serializers.ChoiceField(required=True, choices=NoReferencesChoices.choices,
error_messages=ErrMessage.char("无引用状态"))
value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))
class DatasetSettingSerializer(serializers.Serializer):
top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
error_messages=ErrMessage.float("引用分段数"))
@@ -85,6 +97,8 @@ class DatasetSettingSerializer(serializers.Serializer):
message="类型只支持register|reset_password", code=500)
], error_messages=ErrMessage.char("检索模式"))
no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("未引用分段设置"))
class ModelSettingSerializer(serializers.Serializer):
prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词"))
@@ -383,7 +397,9 @@ class ApplicationSerializer(serializers.Serializer):
application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False
del application['dialogue_number']
if 'dataset_setting' in application:
application['dataset_setting'] = {**application['dataset_setting'], 'search_mode': 'embedding'}
application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': {
'status': 'ai_questioning',
'value': '{question}'}, **application['dataset_setting']}
return application
def page(self, current_page: int, page_size: int, with_valid=True):
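Note the merge order above: the defaults come first and the persisted dataset_setting is unpacked last, so applications saved before this commit gain the new keys while explicitly stored values win. A quick self-contained check of that behavior:

DEFAULTS = {
    'search_mode': 'embedding',
    'no_references_setting': {'status': 'ai_questioning', 'value': '{question}'},
}

old_record = {'top_n': 3}  # saved before this feature existed
new_record = {'top_n': 3,
              'no_references_setting': {'status': 'designated_answer', 'value': 'N/A'}}

# Later unpacking wins, so stored settings override the defaults.
assert {**DEFAULTS, **old_record}['no_references_setting']['status'] == 'ai_questioning'
assert {**DEFAULTS, **new_record}['no_references_setting']['status'] == 'designated_answer'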

View File

@@ -78,7 +78,11 @@ class ChatInfo:
'problem_optimization': self.application.problem_optimization,
'stream': True,
'search_mode': self.application.dataset_setting.get(
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding'
'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding',
'no_references_setting': self.application.dataset_setting.get(
'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else {
'status': 'ai_questioning',
'value': '{question}'}
}
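The membership test above keeps older applications working; it is effectively equivalent to dict.get with a default, as this illustrative snippet shows:

dataset_setting = {}  # e.g. an application created before this commit
no_references_setting = dataset_setting.get(
    'no_references_setting',
    {'status': 'ai_questioning', 'value': '{question}'})
assert no_references_setting['status'] == 'ai_questioning'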

View File

@@ -176,6 +176,18 @@ class ApplicationApi(ApiMixin):
description="最多引用字符数", default=3000),
'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
description="embedding|keywords|blend", default='embedding'),
'no_references_setting': openapi.Schema(type=openapi.TYPE_OBJECT, title='无引用分段设置',
required=['status', 'value'],
properties={
'status': openapi.Schema(type=openapi.TYPE_STRING,
title="状态",
description="ai作答:ai_questioning,指定回答:designated_answer",
default='ai_questioning'),
'value': openapi.Schema(type=openapi.TYPE_STRING,
title="",
description="ai作答:就是题词,指定回答:就是指定回答内容",
default='{question}'),
}),
}
)

View File

@@ -335,3 +335,19 @@
.auto-tooltip-popper {
max-width: 500px;
}
// radio: display one option per line
.radio-block {
width: 100%;
display: block;
.el-radio {
align-items: flex-start;
height: 100%;
width: 100%;
}
.el-radio__label {
width: 100%;
margin-top: -8px;
line-height: 30px;
}
}

View File

@@ -1,6 +1,6 @@
<template>
<el-dialog title="设置 Logo" v-model="dialogVisible">
<el-radio-group v-model="radioType" class="card__block mb-16">
<el-radio-group v-model="radioType" class="radio-block mb-16">
<div>
<el-radio value="default">
<p>默认 Logo</p>
@@ -14,7 +14,7 @@
/>
</el-radio>
</div>
<div>
<div class="mt-8">
<el-radio value="custom">
<p>自定义上传</p>
<div class="flex mt-8">
@@ -126,16 +126,4 @@ function submit() {
defineExpose({ open })
</script>
<style lang="scss" scope>
.card__block {
width: 100%;
display: block;
.el-radio {
align-items: flex-start;
height: 100%;
}
.el-radio__inner {
margin-top: 3px;
}
}
</style>
<style lang="scss" scope></style>

View File

@@ -322,7 +322,11 @@ const applicationForm = ref<ApplicationFormType>({
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000,
search_mode: 'embedding'
search_mode: 'embedding',
no_references_setting: {
status: 'ai_questioning',
value: '{question}'
}
},
model_setting: {
prompt: defaultPrompt

View File

@@ -71,14 +71,65 @@
class="custom-slider"
/>
</el-form-item>
<el-form-item label="无引用知识库分段时">
<el-form
label-position="top"
ref="noReferencesformRef"
:model="noReferencesform"
:rules="noReferencesRules"
class="w-full"
:hide-required-asterisk="true"
>
<el-radio-group
v-model="form.no_references_setting.status"
class="radio-block mb-16"
>
<div>
<el-radio value="ai_questioning">
<p>继续向 AI 模型提问</p>
<el-form-item
v-if="form.no_references_setting.status === 'ai_questioning'"
label="提示词"
prop="ai_questioning"
>
<el-input
v-model="noReferencesform.ai_questioning"
:rows="2"
type="textarea"
maxlength="2048"
:placeholder="defaultValue['ai_questioning']"
/>
</el-form-item>
</el-radio>
</div>
<div class="mt-8">
<el-radio value="designated_answer">
<p>指定回答内容</p>
<el-form-item
v-if="form.no_references_setting.status === 'designated_answer'"
prop="designated_answer"
>
<el-input
v-model="noReferencesform.designated_answer"
:rows="2"
type="textarea"
maxlength="2048"
:placeholder="defaultValue['designated_answer']"
/>
</el-form-item>
</el-radio>
</div>
</el-radio-group>
</el-form>
</el-form-item>
</el-form>
</div>
</el-scrollbar>
</div>
<template #footer>
<span class="dialog-footer">
<span class="dialog-footer p-16">
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
<el-button type="primary" @click="submit(paramFormRef)" :loading="loading">
<el-button type="primary" @click="submit(noReferencesformRef)" :loading="loading">
保存
</el-button>
</span>
@@ -86,18 +137,40 @@
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch } from 'vue'
import { ref, watch, reactive } from 'vue'
import { cloneDeep } from 'lodash'
import type { FormInstance, FormRules } from 'element-plus'
const emit = defineEmits(['refresh'])
const paramFormRef = ref()
const noReferencesformRef = ref()
const defaultValue = {
ai_questioning: '{question}',
designated_answer:
'你好,我是 MaxKB 小助手,我的知识库只包含了 MaxKB 产品相关知识,请重新描述您的问题。'
}
const form = ref<any>({
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
max_paragraph_char_number: 5000,
no_references_setting: {
status: 'ai_questioning',
value: '{question}'
}
})
const noReferencesform = ref<any>({
ai_questioning: defaultValue['ai_questioning'],
designated_answer: defaultValue['designated_answer']
})
const noReferencesRules = reactive<FormRules<any>>({
ai_questioning: [{ required: true, message: '请输入提示词', trigger: 'blur' }],
designated_answer: [{ required: true, message: '请输入内容', trigger: 'blur' }]
})
const dialogVisible = ref<boolean>(false)
@@ -109,13 +182,24 @@ watch(dialogVisible, (bool) => {
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
max_paragraph_char_number: 5000,
no_references_setting: {
status: 'ai_questioning',
value: ''
}
}
noReferencesform.value = {
ai_questioning: defaultValue['ai_questioning'],
designated_answer: defaultValue['designated_answer']
}
noReferencesformRef.value?.clearValidate()
}
})
const open = (data: any) => {
form.value = { ...form.value, ...data }
form.value = { ...form.value, ...cloneDeep(data) }
noReferencesform.value[form.value.no_references_setting.status] =
form.value.no_references_setting.value
dialogVisible.value = true
}
@@ -123,6 +207,8 @@ const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
form.value.no_references_setting.value =
noReferencesform.value[form.value.no_references_setting.status]
emit('refresh', form.value)
dialogVisible.value = false
}
@@ -133,7 +219,7 @@ defineExpose({ open })
</script>
<style lang="scss" scope>
.param-dialog {
padding: 8px;
padding: 8px 8px 24px 8px;
.el-dialog__header {
padding: 16px 16px 0 16px;
}