mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
refactor: 处理历史会话中图片的问题
This commit is contained in:
parent
b406a8954e
commit
832d2a2e7f
|
|
@ -15,7 +15,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
|
|||
|
||||
self.context['document_list'] = document
|
||||
content = ''
|
||||
spliter = '\n-----------------------------------\n'
|
||||
splitter = '\n-----------------------------------\n'
|
||||
if document is None:
|
||||
return NodeResult({'content': content}, {})
|
||||
|
||||
|
|
@ -29,7 +29,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
|
|||
# 回到文件头
|
||||
buffer.seek(0)
|
||||
file_content = split_handle.get_content(buffer)
|
||||
content += spliter + '## ' + doc['name'] + '\n' + file_content
|
||||
content += splitter + '## ' + doc['name'] + '\n' + file_content
|
||||
break
|
||||
|
||||
return NodeResult({'content': content}, {})
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
|
|||
# 多轮对话数量
|
||||
dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
|
||||
|
||||
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
|
||||
|
||||
is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
|
||||
|
||||
image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
|
||||
|
|
@ -32,7 +34,7 @@ class IImageUnderstandNode(INode):
|
|||
self.node_params_serializer.data.get('image_list')[1:])
|
||||
return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
|
||||
chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
|
|
|
|||
|
|
@ -63,7 +63,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
self.context['question'] = details.get('question')
|
||||
self.answer_text = details.get('answer')
|
||||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
|
||||
image,
|
||||
**kwargs) -> NodeResult:
|
||||
image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
|
||||
|
|
@ -71,10 +71,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
# todo 处理上传图片
|
||||
message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
|
||||
self.context['message_list'] = message_list
|
||||
self.context['image_list'] = image
|
||||
self.context['dialogue_type'] = dialogue_type
|
||||
if stream:
|
||||
r = image_model.stream(message_list)
|
||||
return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
|
||||
|
|
@ -86,15 +86,31 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
'history_message': history_message, 'question': question.content}, {},
|
||||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
def get_history_message(self, history_chat_record, dialogue_number):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
[self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
||||
def generate_history_human_message(self, chat_record):
|
||||
|
||||
for data in chat_record.details.values():
|
||||
if self.node.id == data['node_id'] and 'image_list' in data:
|
||||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id = image_list[0]['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
|
||||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
def generate_prompt_question(self, prompt):
|
||||
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
|
||||
|
||||
|
|
@ -148,5 +164,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
'answer_tokens': self.context.get('answer_tokens'),
|
||||
'status': self.status,
|
||||
'err_message': self.err_message,
|
||||
'image_list': self.context.get('image_list')
|
||||
'image_list': self.context.get('image_list'),
|
||||
'dialogue_type': self.context.get('dialogue_type')
|
||||
}
|
||||
|
|
|
|||
|
|
@ -442,7 +442,8 @@ class ChatView(APIView):
|
|||
def post(self, request: Request, application_id: str, chat_id: str):
|
||||
files = request.FILES.getlist('file')
|
||||
file_ids = []
|
||||
meta = {'application_id': application_id, 'chat_id': chat_id}
|
||||
debug = request.data.get("debug", "false").lower() == "true"
|
||||
meta = {'application_id': application_id, 'chat_id': chat_id, 'debug': debug}
|
||||
for file in files:
|
||||
file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
|
||||
file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
|
||||
|
|
|
|||
|
|
@ -8,8 +8,10 @@
|
|||
"""
|
||||
from .client_access_num_job import *
|
||||
from .clean_chat_job import *
|
||||
from .clean_debug_file_job import *
|
||||
|
||||
|
||||
def run():
|
||||
client_access_num_job.run()
|
||||
clean_chat_job.run()
|
||||
clean_debug_file_job.run()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
# coding=utf-8
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
from django_apscheduler.jobstores import DjangoJobStore
|
||||
|
||||
from common.lock.impl.file_lock import FileLock
|
||||
from dataset.models import File
|
||||
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_jobstore(DjangoJobStore(), "default")
|
||||
lock = FileLock()
|
||||
|
||||
|
||||
def clean_debug_file():
|
||||
logging.getLogger("max_kb").info('开始清理debug文件')
|
||||
two_hours_ago = timezone.now() - timedelta(hours=2)
|
||||
# 删除对应的文件
|
||||
File.objects.filter(Q(create_time__lt=two_hours_ago) & Q(meta__debug=True)).delete()
|
||||
logging.getLogger("max_kb").info('结束清理debug文件')
|
||||
|
||||
|
||||
def run():
|
||||
if lock.try_lock('clean_debug_file', 30 * 30):
|
||||
try:
|
||||
scheduler.start()
|
||||
clean_debug_file_job = scheduler.get_job(job_id='clean_debug_file')
|
||||
if clean_debug_file_job is not None:
|
||||
clean_debug_file_job.remove()
|
||||
scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0', id='clean_debug_file')
|
||||
finally:
|
||||
lock.un_lock('clean_debug_file')
|
||||
|
|
@ -100,6 +100,7 @@ const props = withDefaults(
|
|||
appId?: string
|
||||
chatId: string
|
||||
sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void
|
||||
openChatId: () => Promise<string>
|
||||
}>(),
|
||||
{
|
||||
applicationDetails: () => ({}),
|
||||
|
|
@ -165,8 +166,14 @@ const uploadFile = async (file: any, fileList: any) => {
|
|||
}
|
||||
|
||||
if (!chatId_context.value) {
|
||||
const res = await applicationApi.getChatOpen(props.applicationDetails.id as string)
|
||||
chatId_context.value = res.data
|
||||
const res = await props.openChatId()
|
||||
chatId_context.value = res
|
||||
}
|
||||
|
||||
if (props.type === 'debug-ai-chat') {
|
||||
formData.append('debug', 'true')
|
||||
} else {
|
||||
formData.append('debug', 'false')
|
||||
}
|
||||
|
||||
applicationApi
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@
|
|||
:is-mobile="isMobile"
|
||||
:type="type"
|
||||
:send-message="sendMessage"
|
||||
:open-chat-id="openChatId"
|
||||
v-model:chat-id="chartOpenId"
|
||||
v-model:loading="loading"
|
||||
v-if="type !== 'log'"
|
||||
|
|
|
|||
|
|
@ -128,7 +128,16 @@
|
|||
@submitDialog="submitDialog"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="历史聊天记录">
|
||||
<el-form-item>
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<div>历史聊天记录</div>
|
||||
<el-select v-model="form_data.dialogue_type" class="w-120">
|
||||
<el-option label="节点" value="NODE"/>
|
||||
<el-option label="工作流" value="WORKFLOW"/>
|
||||
</el-select>
|
||||
</div>
|
||||
</template>
|
||||
<el-input-number
|
||||
v-model="form_data.dialogue_number"
|
||||
:min="0"
|
||||
|
|
@ -213,6 +222,7 @@ const form = {
|
|||
system: '',
|
||||
prompt: defaultPrompt,
|
||||
dialogue_number: 0,
|
||||
dialogue_type: 'NODE',
|
||||
is_result: true,
|
||||
temperature: null,
|
||||
max_tokens: null,
|
||||
|
|
|
|||
Loading…
Reference in New Issue