refactor: handle images in historical chat sessions

CaptainB 2024-11-14 15:28:15 +08:00 committed by 刘瑞斌
parent b406a8954e
commit 832d2a2e7f
9 changed files with 89 additions and 13 deletions


@@ -15,7 +15,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
         self.context['document_list'] = document
         content = ''
-        spliter = '\n-----------------------------------\n'
+        splitter = '\n-----------------------------------\n'
         if document is None:
             return NodeResult({'content': content}, {})
@@ -29,7 +29,7 @@ class BaseDocumentExtractNode(IDocumentExtractNode):
                     # seek back to the start of the file
                     buffer.seek(0)
                     file_content = split_handle.get_content(buffer)
-                    content += spliter + '## ' + doc['name'] + '\n' + file_content
+                    content += splitter + '## ' + doc['name'] + '\n' + file_content
                     break
         return NodeResult({'content': content}, {})
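The rename only fixes the misspelled `spliter`; the separator itself is unchanged. A minimal, self-contained sketch (illustrative names, not the project's code) of how the extract node assembles the combined content, prefixing each document with the separator and a `## name` heading:

splitter = '\n-----------------------------------\n'

def join_documents(document_list):
    # document_list: [{'name': ..., 'content': ...}] -- a simplified stand-in
    # for the (file, split_handle) loop in the node above
    content = ''
    for doc in document_list:
        content += splitter + '## ' + doc['name'] + '\n' + doc['content']
    return content

print(join_documents([{'name': 'a.txt', 'content': 'hello'},
                      {'name': 'b.txt', 'content': 'world'}]))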


@@ -16,6 +16,8 @@ class ImageUnderstandNodeSerializer(serializers.Serializer):
     # number of rounds kept for multi-turn dialogue
     dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量"))
+    dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("对话存储类型"))
     is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))
     image_list = serializers.ListField(required=False, error_messages=ErrMessage.list("图片"))
@@ -32,7 +34,7 @@ class IImageUnderstandNode(INode):
             self.node_params_serializer.data.get('image_list')[1:])
         return self.execute(image=res, **self.node_params_serializer.data, **self.flow_params_serializer.data)

-    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id,
                 chat_record_id,
                 image,
                 **kwargs) -> NodeResult:
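Adding `dialogue_type` to the serializer matters because, as the call above shows, the validated data is splatted straight into `execute()`; a field missing from the signature would silently land in `**kwargs`. A hedged, standalone sketch of that calling convention (all parameter values made up):

def execute(model_id, system, prompt, dialogue_number, dialogue_type,
            history_chat_record, stream, chat_id, chat_record_id, image,
            **kwargs):
    return dialogue_type

node_params = {'model_id': 'm1', 'system': '', 'prompt': 'describe the image',
               'dialogue_number': 1, 'dialogue_type': 'NODE'}
flow_params = {'history_chat_record': [], 'stream': False,
               'chat_id': 'c1', 'chat_record_id': 'r1'}
assert execute(image=[], **node_params, **flow_params) == 'NODE'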


@@ -63,7 +63,7 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
         self.context['question'] = details.get('question')
         self.answer_text = details.get('answer')

-    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
+    def execute(self, model_id, system, prompt, dialogue_number, dialogue_type, history_chat_record, stream, chat_id, chat_record_id,
                 image,
                 **kwargs) -> NodeResult:
         image_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
@@ -71,10 +71,10 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
         self.context['history_message'] = history_message
         question = self.generate_prompt_question(prompt)
         self.context['question'] = question.content
-        # todo: handle uploaded images
         message_list = self.generate_message_list(image_model, system, prompt, history_message, image)
         self.context['message_list'] = message_list
         self.context['image_list'] = image
+        self.context['dialogue_type'] = dialogue_type
         if stream:
             r = image_model.stream(message_list)
             return NodeResult({'result': r, 'chat_model': image_model, 'message_list': message_list,
@@ -86,15 +86,31 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                                'history_message': history_message, 'question': question.content}, {},
                               _write_context=write_context)

-    @staticmethod
-    def get_history_message(history_chat_record, dialogue_number):
+    def get_history_message(self, history_chat_record, dialogue_number):
         start_index = len(history_chat_record) - dialogue_number
         history_message = reduce(lambda x, y: [*x, *y], [
-            [history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
+            [self.generate_history_human_message(history_chat_record[index]), history_chat_record[index].get_ai_message()]
             for index in
             range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
         return history_message

+    def generate_history_human_message(self, chat_record):
+        for data in chat_record.details.values():
+            if self.node.id == data['node_id'] and 'image_list' in data:
+                image_list = data['image_list']
+                if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
+                    return HumanMessage(content=chat_record.problem_text)
+                file_id = image_list[0]['file_id']
+                file = QuerySet(File).filter(id=file_id).first()
+                base64_image = base64.b64encode(file.get_byte()).decode("utf-8")
+                return HumanMessage(
+                    content=[
+                        {'type': 'text', 'text': data['question']},
+                        {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{base64_image}'}},
+                    ])
+        return HumanMessage(content=chat_record.problem_text)
+
     def generate_prompt_question(self, prompt):
         return HumanMessage(self.workflow_manage.generate_prompt(prompt))
@@ -148,5 +164,6 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
             'answer_tokens': self.context.get('answer_tokens'),
             'status': self.status,
             'err_message': self.err_message,
-            'image_list': self.context.get('image_list')
+            'image_list': self.context.get('image_list'),
+            'dialogue_type': self.context.get('dialogue_type')
         }
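The new `generate_history_human_message` is what actually fixes images in historical sessions: when a record's details carry an `image_list` for this node and the history is stored per node (`dialogue_type == 'NODE'`), the first image is re-read from storage and re-embedded as a base64 data URL; with `WORKFLOW` storage or no images, only the plain question text is replayed. A simplified, dependency-free sketch of that branching (assumed semantics, not the project's code):

import base64

def rebuild_human_message(question, problem_text, image_bytes, dialogue_type):
    # mirrors the branching above: plain text for WORKFLOW storage or a
    # missing image, multimodal content otherwise
    if image_bytes is None or dialogue_type == 'WORKFLOW':
        return {'content': problem_text}
    b64 = base64.b64encode(image_bytes).decode('utf-8')
    return {'content': [
        {'type': 'text', 'text': question},
        {'type': 'image_url', 'image_url': {'url': f'data:image/jpeg;base64,{b64}'}},
    ]}

print(rebuild_human_message('What is shown here?', 'What is shown here?',
                            b'\xff\xd8\xff\xe0', 'NODE'))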


@@ -442,7 +442,8 @@ class ChatView(APIView):
     def post(self, request: Request, application_id: str, chat_id: str):
         files = request.FILES.getlist('file')
         file_ids = []
-        meta = {'application_id': application_id, 'chat_id': chat_id}
+        debug = request.data.get("debug", "false").lower() == "true"
+        meta = {'application_id': application_id, 'chat_id': chat_id, 'debug': debug}
         for file in files:
             file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
             file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]})
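Multipart form values arrive as strings, so the view normalizes the `debug` field with `.lower() == "true"` rather than truthiness. A small sketch of the edge cases that comparison handles:

def parse_debug(value: str) -> bool:
    return value.lower() == 'true'

assert parse_debug('true') and parse_debug('True')
assert not parse_debug('false') and not parse_debug('')  # anything else is False
assert not parse_debug('1')  # note: '1' is not treated as true here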


@@ -8,8 +8,10 @@
 """
 from .client_access_num_job import *
 from .clean_chat_job import *
+from .clean_debug_file_job import *


 def run():
     client_access_num_job.run()
     clean_chat_job.run()
+    clean_debug_file_job.run()


@@ -0,0 +1,36 @@
+# coding=utf-8
+import logging
+from datetime import timedelta
+
+from apscheduler.schedulers.background import BackgroundScheduler
+from django.db.models import Q
+from django.utils import timezone
+from django_apscheduler.jobstores import DjangoJobStore
+
+from common.lock.impl.file_lock import FileLock
+from dataset.models import File
+
+scheduler = BackgroundScheduler()
+scheduler.add_jobstore(DjangoJobStore(), "default")
+lock = FileLock()
+
+
+def clean_debug_file():
+    logging.getLogger("max_kb").info('开始清理debug文件')
+    two_hours_ago = timezone.now() - timedelta(hours=2)
+    # delete the matching debug files
+    File.objects.filter(Q(create_time__lt=two_hours_ago) & Q(meta__debug=True)).delete()
+    logging.getLogger("max_kb").info('结束清理debug文件')
+
+
+def run():
+    if lock.try_lock('clean_debug_file', 30 * 30):
+        try:
+            scheduler.start()
+            clean_debug_file_job = scheduler.get_job(job_id='clean_debug_file')
+            if clean_debug_file_job is not None:
+                clean_debug_file_job.remove()
+            scheduler.add_job(clean_debug_file, 'cron', hour='2', minute='0', second='0', id='clean_debug_file')
+        finally:
+            lock.un_lock('clean_debug_file')
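The `meta__debug=True` filter relies on `File.meta` being a Django `JSONField`, where double-underscore lookups descend into the JSON document; only uploads whose meta the view above stamped with `debug: true` are eligible for deletion. A hedged sketch of the pairing (assuming that model layout, inside a configured Django project):

from datetime import timedelta

from django.db.models import Q
from django.utils import timezone

from dataset.models import File

# upload side: the view stores debug=True in meta for debug-chat uploads
meta = {'application_id': 'app-1', 'chat_id': 'chat-1', 'debug': True}

# cleanup side: the daily 02:00 job removes debug uploads older than two hours
cutoff = timezone.now() - timedelta(hours=2)
File.objects.filter(Q(create_time__lt=cutoff) & Q(meta__debug=True)).delete()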


@@ -100,6 +100,7 @@ const props = withDefaults(
     appId?: string
     chatId: string
     sendMessage: (question: string, other_params_data?: any, chat?: chatType) => void
+    openChatId: () => Promise<string>
   }>(),
   {
     applicationDetails: () => ({}),
@@ -165,8 +166,14 @@ const uploadFile = async (file: any, fileList: any) => {
   }
   if (!chatId_context.value) {
-    const res = await applicationApi.getChatOpen(props.applicationDetails.id as string)
-    chatId_context.value = res.data
+    const res = await props.openChatId()
+    chatId_context.value = res
   }
+  if (props.type === 'debug-ai-chat') {
+    formData.append('debug', 'true')
+  } else {
+    formData.append('debug', 'false')
+  }
   applicationApi


@@ -38,6 +38,7 @@
       :is-mobile="isMobile"
       :type="type"
       :send-message="sendMessage"
+      :open-chat-id="openChatId"
       v-model:chat-id="chartOpenId"
       v-model:loading="loading"
       v-if="type !== 'log'"


@@ -128,7 +128,16 @@
         @submitDialog="submitDialog"
       />
     </el-form-item>
-    <el-form-item label="历史聊天记录">
+    <el-form-item>
+      <template #label>
+        <div class="flex-between">
+          <div>历史聊天记录</div>
+          <el-select v-model="form_data.dialogue_type" class="w-120">
+            <el-option label="节点" value="NODE"/>
+            <el-option label="工作流" value="WORKFLOW"/>
+          </el-select>
+        </div>
+      </template>
       <el-input-number
         v-model="form_data.dialogue_number"
         :min="0"
@@ -213,6 +222,7 @@ const form = {
   system: '',
   prompt: defaultPrompt,
   dialogue_number: 0,
+  dialogue_type: 'NODE',
   is_result: true,
   temperature: null,
   max_tokens: null,