perf: 应用的AI模型修改为不必填 (#297)

This commit is contained in:
shaohuzhang1 2024-04-28 17:09:12 +08:00 committed by GitHub
parent 5705f3c4a8
commit 7b5ccd9089
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 91 additions and 79 deletions

View File

@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表"))
# 大语言模型
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
# 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id

View File

@ -126,6 +126,26 @@ class BaseChatStep(IChatStep):
result.append({'role': 'ai', 'content': answer_text})
return result
@staticmethod
def get_stream_result(message_list: List[BaseMessage],
                      chat_model: BaseChatModel = None,
                      paragraph_list=None,
                      no_references_setting=None):
    """Resolve the streaming answer source for a chat turn.

    Returns a tuple ``(chunk_iterator, is_ai_chat)`` where ``is_ai_chat`` is
    True only when the LLM was actually invoked (callers use it to decide
    whether to count tokens).

    :param message_list: conversation messages to send to the model
    :param chat_model: optional LLM; the application model is now optional,
        so this may be None
    :param paragraph_list: retrieved knowledge-base paragraphs (may be None)
    :param no_references_setting: dict with 'status'/'value' describing the
        no-references behaviour; may be None — guarded below
    """
    if paragraph_list is None:
        paragraph_list = []
    # Paragraphs flagged 'directly_return' bypass the model entirely and are
    # streamed back verbatim.
    directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
                                  for paragraph in paragraph_list if
                                  paragraph.hit_handling_method == 'directly_return']
    if len(directly_return_chunk_list) > 0:
        return iter(directly_return_chunk_list), False
    # Guard: no_references_setting defaults to None; the original code called
    # .get() on it unconditionally and would raise AttributeError.
    if (no_references_setting or {}).get('status') == 'designated_answer':
        return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
    if chat_model is None:
        # No model configured for the application: return a fixed fallback reply.
        return iter([AIMessageChunk(content='抱歉,没有在知识库中查询到相关信息。')]), False
    return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage],
chat_id,
problem_text,
@ -136,29 +156,8 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None,
client_id=None, client_type=None,
no_references_setting=None):
is_ai_chat = False
# 调用模型
if chat_model is None:
chat_result = iter(
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@ -169,6 +168,27 @@ class BaseChatStep(IChatStep):
r['Cache-Control'] = 'no-cache'
return r
@staticmethod
def get_block_result(message_list: List[BaseMessage],
                     chat_model: BaseChatModel = None,
                     paragraph_list=None,
                     no_references_setting=None):
    """Resolve the blocking (non-streaming) answer for a chat turn.

    Returns a tuple ``(ai_message, is_ai_chat)`` where ``is_ai_chat`` is
    True only when the LLM was actually invoked (callers use it to decide
    whether to count tokens).

    :param message_list: conversation messages to send to the model
    :param chat_model: optional LLM; the application model is now optional,
        so this may be None
    :param paragraph_list: retrieved knowledge-base paragraphs (may be None)
    :param no_references_setting: dict with 'status'/'value' describing the
        no-references behaviour; may be None — guarded below
    """
    if paragraph_list is None:
        paragraph_list = []
    # Paragraphs flagged 'directly_return' bypass the model; only the first
    # one is returned in blocking mode.
    directly_return_chunk_list = [AIMessage(content=paragraph.content)
                                  for paragraph in paragraph_list if
                                  paragraph.hit_handling_method == 'directly_return']
    if len(directly_return_chunk_list) > 0:
        return directly_return_chunk_list[0], False
    # Guard: no_references_setting defaults to None; the original code called
    # .get() on it unconditionally and would raise AttributeError.
    if (no_references_setting or {}).get('status') == 'designated_answer':
        return AIMessage(content=no_references_setting.get('value')), False
    if chat_model is None:
        # No model configured for the application: return a fixed fallback reply.
        return AIMessage(content='抱歉,没有在知识库中查询到相关信息。'), False
    return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage],
chat_id,
problem_text,
@ -178,28 +198,8 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None):
is_ai_chat = False
# 调用模型
if chat_model is None:
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
chat_record_id = uuid.uuid1()
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)

View File

@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer

View File

@ -22,6 +22,8 @@ prompt = (
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs) -> str:
if chat_model is None:
return problem_text
start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in

View File

@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型id"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid(
"知识库id")),
@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
super().is_valid(raise_exception=True)
model_id = self.data.get('model_id')
user_id = self.data.get('user_id')
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
if model_id is not None and len(model_id) > 0:
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
dataset_id_list = list(set(self.data.get('dataset_id_list')))
exist_dataset_id_list = [str(dataset.id) for dataset in
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
error_messages=ErrMessage.char("开场白"))
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer):
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
model = QuerySet(Model).filter(
id=instance.get('model_id') if 'model_id' in instance else application.model_id,
user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
application.model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('model_id'),
user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization',
'api_key_is_active', 'icon']

View File

@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
model = chat_info.application.model
if model is None:
return chat_info
model = QuerySet(Model).filter(id=model.id).first()
if model is None:
raise AppApiException(500, "模型不存在")
return chat_info
if model.status == Status.ERROR:
raise AppApiException(500, "当前模型不可用")
if model.status == Status.DOWNLOAD:

View File

@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer):
id = serializers.UUIDField(required=False, allow_null=True,
error_messages=ErrMessage.uuid("应用id"))
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.uuid("模型id"))
multiple_rounds_dialogue = serializers.BooleanField(required=True,
error_messages=ErrMessage.boolean("多轮会话"))
@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer):
def open(self):
user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
model_id = self.data.get('model_id')
if model_id is not None and len(model_id) > 0:
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),

View File

@ -224,7 +224,7 @@ const chartOpenId = ref('')
const chatList = ref<any[]>([])
const isDisabledChart = computed(
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id)))
() => !(inputValue.value.trim() && (props.appId || props.data?.name))
)
const isMdArray = (val: string) => val.match(/^-\s.*/m)
const prologueList = computed(() => {
@ -509,16 +509,14 @@ function regenerationChart(item: chatType) {
}
function getSourceDetail(row: any) {
logApi
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading)
.then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
})
return true
}

View File

@ -48,7 +48,7 @@
<el-form-item label="AI 模型" prop="model_id">
<template #label>
<div class="flex-between">
<span>AI 模型 <span class="danger">*</span></span>
<span>AI 模型 </span>
</div>
</template>
<el-select
@ -56,6 +56,7 @@
placeholder="请选择 AI 模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
model_id: [
{
required: true,
required: false,
message: '请选择模型',
trigger: 'change'
}