perf: Make the application's AI model optional (#297)

Repository: https://github.com/1Panel-dev/MaxKB.git
Commit 7b5ccd9089 (parent 5705f3c4a8)

The diff touches the chat pipeline steps, the application and chat serializers, and the web form: everywhere a chat model or model_id used to be mandatory, it is now optional.
@@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
         message_list = serializers.ListField(required=True, child=MessageField(required=True),
                                              error_messages=ErrMessage.list("对话列表"))
         # 大语言模型
-        chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
+        chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
         # 段落列表
         paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
         # 对话id
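This serializer change is the pattern the whole commit repeats: the chat model becomes an optional input. For reference, a minimal standalone sketch of the same Django REST Framework idiom, using a plain CharField as a stand-in for MaxKB's custom ModelField (the settings.configure() call is only there so the snippet runs outside a Django project; assumes djangorestframework is installed):

# Hedged sketch: how required=False / allow_null=True changes DRF validation.
# CharField is a stand-in for MaxKB's ModelField; the Django settings are dummies.
import django
from django.conf import settings

settings.configure()  # minimal config so DRF fields work outside a project
django.setup()

from rest_framework import serializers

class StepSerializer(serializers.Serializer):
    # before this commit: ModelField(...) with DRF's default required=True
    # after: the field may be omitted entirely or sent as null
    chat_model = serializers.CharField(required=False, allow_null=True)

print(StepSerializer(data={}).is_valid())                    # True - key may be absent
print(StepSerializer(data={'chat_model': None}).is_valid())  # True - null is accepted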
@@ -126,6 +126,26 @@ class BaseChatStep(IChatStep):
             result.append({'role': 'ai', 'content': answer_text})
         return result

+    @staticmethod
+    def get_stream_result(message_list: List[BaseMessage],
+                          chat_model: BaseChatModel = None,
+                          paragraph_list=None,
+                          no_references_setting=None):
+        if paragraph_list is None:
+            paragraph_list = []
+        directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
+                                      for paragraph in paragraph_list if
+                                      paragraph.hit_handling_method == 'directly_return']
+        if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
+            return iter(directly_return_chunk_list), False
+        elif no_references_setting.get(
+                'status') == 'designated_answer':
+            return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
+        if chat_model is None:
+            return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False
+        else:
+            return chat_model.stream(message_list), True
+
     def execute_stream(self, message_list: List[BaseMessage],
                        chat_id,
                        problem_text,
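The new get_stream_result helper centralizes the answer-source decision that previously lived inline in execute_stream; the boolean in the returned tuple records whether the answer truly came from the LLM. Reduced to a standalone sketch (the Paragraph dataclass and pick_stream_source name are illustrative stand-ins, not MaxKB API), the precedence is: directly-returned paragraphs first, then a configured designated answer, then a fixed apology when no model is attached, and only last the model itself:

# Illustrative re-implementation of the fallback order in get_stream_result.
from dataclasses import dataclass

@dataclass
class Paragraph:
    content: str
    hit_handling_method: str = 'optimization'

def pick_stream_source(message_list, chat_model, paragraph_list, no_references_setting):
    paragraph_list = paragraph_list or []
    # 1. paragraphs flagged 'directly_return' short-circuit everything else
    direct = [p.content for p in paragraph_list if p.hit_handling_method == 'directly_return']
    if direct:
        return iter(direct), False
    # 2. otherwise a configured designated answer is emitted verbatim
    if no_references_setting.get('status') == 'designated_answer':
        return iter([no_references_setting.get('value')]), False
    # 3. with no model configured (now legal), fall back to a fixed apology
    if chat_model is None:
        return iter(['抱歉,没有在知识库中查询到相关信息。']), False
    # 4. only here does the LLM run; True marks a genuine AI answer
    return chat_model.stream(message_list), True

# e.g. a directly-returned paragraph wins even though no model is configured:
source, is_ai = pick_stream_source([], None, [Paragraph('已命中的段落', 'directly_return')],
                                   {'status': 'ai_questioning'})
print(list(source), is_ai)  # ['已命中的段落'] False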
@@ -136,29 +156,8 @@ class BaseChatStep(IChatStep):
                        padding_problem_text: str = None,
                        client_id=None, client_type=None,
                        no_references_setting=None):
-        is_ai_chat = False
-        # 调用模型
-        if chat_model is None:
-            chat_result = iter(
-                [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
-        else:
-            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
-                    'status') == 'designated_answer':
-                chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
-            else:
-                if paragraph_list is not None and len(paragraph_list) > 0:
-                    directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
-                                                  for paragraph in paragraph_list if
-                                                  paragraph.hit_handling_method == 'directly_return']
-                    if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
-                        chat_result = iter(directly_return_chunk_list)
-                    else:
-                        chat_result = chat_model.stream(message_list)
-                        is_ai_chat = True
-                else:
-                    chat_result = chat_model.stream(message_list)
-                    is_ai_chat = True
-
+        chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
+                                                         no_references_setting)
         chat_record_id = uuid.uuid1()
         r = StreamingHttpResponse(
             streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@@ -169,6 +168,27 @@ class BaseChatStep(IChatStep):
         r['Cache-Control'] = 'no-cache'
         return r

+    @staticmethod
+    def get_block_result(message_list: List[BaseMessage],
+                         chat_model: BaseChatModel = None,
+                         paragraph_list=None,
+                         no_references_setting=None):
+        if paragraph_list is None:
+            paragraph_list = []
+
+        directly_return_chunk_list = [AIMessage(content=paragraph.content)
+                                      for paragraph in paragraph_list if
+                                      paragraph.hit_handling_method == 'directly_return']
+        if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
+            return directly_return_chunk_list[0], False
+        elif no_references_setting.get(
+                'status') == 'designated_answer':
+            return AIMessage(no_references_setting.get('value')), False
+        if chat_model is None:
+            return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False
+        else:
+            return chat_model.invoke(message_list), True
+
     def execute_block(self, message_list: List[BaseMessage],
                       chat_id,
                       problem_text,
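get_block_result is the non-streaming twin of get_stream_result: it builds AIMessage objects instead of AIMessageChunks, calls chat_model.invoke instead of chat_model.stream, and, because a blocking response is a single message, returns only the first directly-returned paragraph (directly_return_chunk_list[0]) where the streaming path iterates over all of them.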
@@ -178,28 +198,8 @@ class BaseChatStep(IChatStep):
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
                       client_id=None, client_type=None, no_references_setting=None):
-        is_ai_chat = False
-        # 调用模型
-        if chat_model is None:
-            chat_result = AIMessage(
-                content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
-        else:
-            if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
-                    'status') == 'designated_answer':
-                chat_result = AIMessage(content=no_references_setting.get('value'))
-            else:
-                if paragraph_list is not None and len(paragraph_list) > 0:
-                    directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
-                                                  for paragraph in paragraph_list if
-                                                  paragraph.hit_handling_method == 'directly_return']
-                    if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
-                        chat_result = iter(directly_return_chunk_list)
-                    else:
-                        chat_result = chat_model.invoke(message_list)
-                        is_ai_chat = True
-                else:
-                    chat_result = chat_model.invoke(message_list)
-                    is_ai_chat = True
+        chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
         chat_record_id = uuid.uuid1()
         if is_ai_chat:
             request_token = chat_model.get_num_tokens_from_messages(message_list)
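In both execute paths the is_ai_chat flag now comes from the helper, and the trailing context shows why it matters: token accounting (chat_model.get_num_tokens_from_messages(...)) only runs when the answer really was produced by the model, which is exactly the case where chat_model is guaranteed non-None.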
@@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
         history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
                                                     error_messages=ErrMessage.list("历史对答"))
         # 大语言模型
-        chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
+        chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))

     def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
         return self.InstanceSerializer
@@ -22,6 +22,8 @@ prompt = (
 class BaseResetProblemStep(IResetProblemStep):
     def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
                 **kwargs) -> str:
+        if chat_model is None:
+            return problem_text
         start_index = len(history_chat_record) - 3
         history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
                            for index in
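With the model optional, problem optimization degrades to a pass-through: no LLM means the user's question is used verbatim. A trimmed sketch of the guard (FakeModel and reset_problem are illustrative; the real step builds a prompt from recent chat history before invoking the model):

# Hedged sketch of the early-return guard added to BaseResetProblemStep.execute.
class FakeModel:
    def invoke(self, text):
        return f'optimized({text})'

def reset_problem(problem_text, chat_model=None):
    if chat_model is None:
        return problem_text  # no LLM -> keep the user's question verbatim
    return chat_model.invoke(problem_text)

print(reset_problem('MaxKB 是什么?'))               # MaxKB 是什么?
print(reset_problem('MaxKB 是什么?', FakeModel()))  # optimized(MaxKB 是什么?)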
@@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']

 class ModelDatasetAssociation(serializers.Serializer):
     user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
-    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
+    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+                                     error_messages=ErrMessage.char("模型id"))
     dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
                                                                                              error_messages=ErrMessage.uuid(
                                                                                                  "知识库id")),
@@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
         super().is_valid(raise_exception=True)
         model_id = self.data.get('model_id')
         user_id = self.data.get('user_id')
-        if not QuerySet(Model).filter(id=model_id).exists():
-            raise AppApiException(500, f'模型不存在【{model_id}】')
+        if model_id is not None and len(model_id) > 0:
+            if not QuerySet(Model).filter(id=model_id).exists():
+                raise AppApiException(500, f'模型不存在【{model_id}】')
         dataset_id_list = list(set(self.data.get('dataset_id_list')))
         exist_dataset_id_list = [str(dataset.id) for dataset in
                                  QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
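Because the field now allows both null and blank values, the existence check has to tolerate None and '' alike, which is what model_id is not None and len(model_id) > 0 does. The same guard in isolation (names here are illustrative, not MaxKB API):

# Hedged sketch of the pattern: validate only when an id was actually supplied.
def check_model(model_id, known_ids):
    if model_id is not None and len(model_id) > 0:
        if model_id not in known_ids:
            raise ValueError(f'模型不存在【{model_id}】')

check_model(None, set())   # ok - field omitted / null
check_model('', set())     # ok - blank string
check_model('m1', {'m1'})  # ok - id exists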
@@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
     desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                  max_length=256, min_length=1,
                                  error_messages=ErrMessage.char("应用描述"))
-    model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
+    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+                                     error_messages=ErrMessage.char("模型"))
     multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
     prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
                                      error_messages=ErrMessage.char("开场白"))
@@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
                                      error_messages=ErrMessage.char("应用名称"))
         desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
                                      error_messages=ErrMessage.char("应用描述"))
-        model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
+        model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
+                                         error_messages=ErrMessage.char("模型"))
         multiple_rounds_dialogue = serializers.BooleanField(required=False,
                                                             error_messages=ErrMessage.boolean("多轮会话"))
         prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
@@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer):
             application_id = self.data.get("application_id")

             application = QuerySet(Application).get(id=application_id)

-            model = QuerySet(Model).filter(
-                id=instance.get('model_id') if 'model_id' in instance else application.model_id,
-                user_id=application.user_id).first()
-            if model is None:
-                raise AppApiException(500, "模型不存在")
-
+            if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
+                application.model_id = None
+            else:
+                model = QuerySet(Model).filter(
+                    id=instance.get('model_id'),
+                    user_id=application.user_id).first()
+                if model is None:
+                    raise AppApiException(500, "模型不存在")
             update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
                            'dataset_setting', 'model_setting', 'problem_optimization',
                            'api_key_is_active', 'icon']
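Two behavioral points in this edit path: an empty or missing model_id now explicitly detaches the model (application.model_id = None), and the old fallback instance.get('model_id') if 'model_id' in instance else application.model_id is gone, so an edit payload that omits the key clears the model rather than keeping the current one.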
@@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
             chat_cache.set(chat_id,
                            chat_info, timeout=60 * 30)
         model = chat_info.application.model
+        if model is None:
+            return chat_info
         model = QuerySet(Model).filter(id=model.id).first()
         if model is None:
-            raise AppApiException(500, "模型不存在")
+            return chat_info
         if model.status == Status.ERROR:
             raise AppApiException(500, "当前模型不可用")
         if model.status == Status.DOWNLOAD:
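Here the validity check appears to return the ChatInfo early instead of raising: once when the application simply has no model attached, and again when the stored model id no longer resolves to a row, so the Status.ERROR and Status.DOWNLOAD checks below only run against a model that actually exists.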
@@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer):

         id = serializers.UUIDField(required=False, allow_null=True,
                                    error_messages=ErrMessage.uuid("应用id"))
-        model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
+        model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+                                         error_messages=ErrMessage.uuid("模型id"))

         multiple_rounds_dialogue = serializers.BooleanField(required=True,
                                                             error_messages=ErrMessage.boolean("多轮会话"))
@@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer):
         def open(self):
             user_id = self.is_valid(raise_exception=True)
             chat_id = str(uuid.uuid1())
-            model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
-            if model is None:
-                raise AppApiException(500, "模型不存在")
+            model_id = self.data.get('model_id')
+            if model_id is not None and len(model_id) > 0:
+                model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
+                chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
+                                                                                   json.loads(
+                                                                                       decrypt(model.credential)),
+                                                                                   streaming=True)
+            else:
+                model = None
+                chat_model = None
             dataset_id_list = self.data.get('dataset_id_list')
-            chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
-                                                                               json.loads(
-                                                                                   decrypt(model.credential)),
-                                                                               streaming=True)
             application = Application(id=None, dialogue_number=3, model=model,
                                       dataset_setting=self.data.get('dataset_setting'),
                                       model_setting=self.data.get('model_setting'),
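In open(), the chat model is now only instantiated when a model id was actually supplied; otherwise both model and chat_model stay None and the optional-model handling shown earlier takes over. Note also that model_id changed from UUIDField to CharField so that a blank value can pass validation, and that the rewritten branch drops the old 模型不存在 existence check, so a stale id that matches no row would leave model as None before model.provider is read. A stand-in for the branch (load_model and make_chat_model are hypothetical helpers, not MaxKB API; the sketch keeps an explicit existence guard for clarity):

# Illustrative sketch of the model-optional branch in ChatSerializers.open().
def open_chat(model_id, load_model, make_chat_model):
    if model_id is not None and len(model_id) > 0:
        model = load_model(model_id)
        if model is None:
            # the rewritten branch drops this existence check; kept here for clarity
            raise ValueError('模型不存在')
        return model, make_chat_model(model)
    # no model selected: the Application is created with model=None and the
    # pipeline's None handling (get_stream_result / get_block_result) applies
    return None, None

print(open_chat('', None, None))  # (None, None) - chat opens without a model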
@@ -224,7 +224,7 @@ const chartOpenId = ref('')
 const chatList = ref<any[]>([])

 const isDisabledChart = computed(
-  () => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id)))
+  () => !(inputValue.value.trim() && (props.appId || props.data?.name))
 )
 const isMdArray = (val: string) => val.match(/^-\s.*/m)
 const prologueList = computed(() => {
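On the frontend, the send button's disabled condition drops the props.data?.model_id requirement, mirroring the backend: an application with a name but no model can still start a chat.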
@@ -509,16 +509,14 @@ function regenerationChart(item: chatType) {
 }

 function getSourceDetail(row: any) {
-  logApi
-    .getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading)
-    .then((res) => {
-      const exclude_keys = ['answer_text', 'id']
-      Object.keys(res.data).forEach((key) => {
-        if (!exclude_keys.includes(key)) {
-          row[key] = res.data[key]
-        }
-      })
-    })
+  logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
+    const exclude_keys = ['answer_text', 'id']
+    Object.keys(res.data).forEach((key) => {
+      if (!exclude_keys.includes(key)) {
+        row[key] = res.data[key]
+      }
+    })
+  })
   return true
 }

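The getSourceDetail hunk above is formatting-only: the chained logApi call is collapsed onto a single line and the callback re-indented, with no behavioral change.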
@@ -48,7 +48,7 @@
           <el-form-item label="AI 模型" prop="model_id">
             <template #label>
               <div class="flex-between">
-                <span>AI 模型 <span class="danger">*</span></span>
+                <span>AI 模型 </span>
               </div>
             </template>
             <el-select
@@ -56,6 +56,7 @@
               placeholder="请选择 AI 模型"
               class="w-full"
               popper-class="select-model"
+              :clearable="true"
             >
               <el-option-group
                 v-for="(value, label) in modelOptions"
@@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
   name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
   model_id: [
     {
-      required: true,
+      required: false,
       message: '请选择模型',
       trigger: 'change'
     }
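Taken together, the form changes make the UI match the serializers: the required "*" marker is dropped from the AI 模型 label, the model dropdown gains :clearable="true" so a previously selected model can be removed, and the model_id form rule flips required to false.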
|
|||
Loading…
Reference in New Issue