feat: 优化应用对话

This commit is contained in:
shaohuzhang1 2024-01-18 18:36:24 +08:00
parent af7c28868d
commit 04d3ec0524
3 changed files with 129 additions and 78 deletions

View File

@ -41,8 +41,6 @@ def event_content(response,
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n"
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
# 获取token
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
@ -56,6 +54,8 @@ def event_content(response,
manage.context['answer_tokens'] = manage.context['answer_tokens'] + response_token
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,

View File

@ -176,12 +176,22 @@ class ChatRecordSerializer(serializers.Serializer):
chat_record_id = serializers.UUIDField(required=True)
def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
def one(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
chat_record = self.get_chat_record()
if chat_record is None:
raise AppApiException(500, "对话不存在")
dataset_list = []
paragraph_list = []
if len(chat_record.paragraph_id_list) > 0:
@ -200,7 +210,7 @@ class ChatRecordSerializer(serializers.Serializer):
return {
**ChatRecordSerializerModel(chat_record).data,
'padding_problem_text': chat_record.details.get(
'padding_problem_text': chat_record.details.get('problem_padding').get(
'padding_problem_text') if 'problem_padding' in chat_record.details else None,
'dataset_list': dataset_list,
'paragraph_list': paragraph_list}

View File

@ -160,7 +160,7 @@
</div>
</template>
<script setup lang="ts">
import { ref, nextTick, computed, watch } from 'vue'
import { ref, nextTick, computed, watch, reactive } from 'vue'
import { useRoute } from 'vue-router'
import LogOperationButton from './LogOperationButton.vue'
import OperationButton from './OperationButton.vue'
@ -172,6 +172,7 @@ import { randomId } from '@/utils/utils'
import useStore from '@/stores'
import MdRenderer from '@/components/markdown-renderer/MdRenderer.vue'
import { MdPreview } from 'md-editor-v3'
import { MsgError } from '@/utils/message'
defineOptions({ name: 'AiChat' })
const route = useRoute()
const {
@ -288,6 +289,83 @@ function getChartOpenId() {
})
}
}
/**
* 获取一个递归函数,处理流式数据
* @param chat 每一条对话记录
* @param reader 流数据
* @param stream 是否是流式数据
*/
const getWrite = (chat: any, reader: any, stream: boolean) => {
let tempResult = ''
/**
*
* @param done 是否结束
* @param value
*/
const write_stream = ({ done, value }: { done: boolean; value: any }) => {
try {
if (done) {
ChatManagement.close(chat.id)
return
}
const decoder = new TextDecoder('utf-8')
let str = decoder.decode(value, { stream: true })
// start chunk chunkdata:{xxx}\n\n data:{ -> xxx}\n\n fetchchunkdata: \n\n
tempResult += str
if (tempResult.endsWith('\n\n')) {
str = tempResult
tempResult = ''
} else {
return reader.read().then(write_stream)
}
// end
if (str && str.startsWith('data:')) {
const split = str.match(/data:.*}\n\n/g)
if (split) {
for (const index in split) {
const chunk = JSON?.parse(split[index].replace('data:', ''))
chat.record_id = chunk.id
const content = chunk?.content
if (content) {
ChatManagement.append(chat.id, content)
}
if (chunk.is_end) {
//
return Promise.resolve()
}
}
}
}
} catch (e) {
return Promise.reject(e)
}
return reader.read().then(write_stream)
}
/**
* 处理 json 响应
* @param param0
*/
const write_json = ({ done, value }: { done: boolean; value: any }) => {
if (done) {
const result_block = JSON.parse(tempResult)
if (result_block.code === 500) {
return Promise.reject(result_block.message)
} else {
if (result_block.content) {
ChatManagement.append(chat.id, result_block.content)
}
}
ChatManagement.close(chat.id)
return
}
if (value) {
const decoder = new TextDecoder('utf-8')
tempResult += decoder.decode(value)
}
return reader.read().then(write_json)
}
return stream ? write_stream : write_json
}
function chatMessage() {
loading.value = true
@ -295,9 +373,8 @@ function chatMessage() {
getChartOpenId()
} else {
const problem_text = inputValue.value
const id = randomId()
chatList.value.push({
id: id,
const chat = reactive({
id: randomId(),
problem_text: problem_text,
answer_text: '',
buffer: [],
@ -306,74 +383,37 @@ function chatMessage() {
record_id: '',
vote_status: '-1'
})
chatList.value.push(chat)
inputValue.value = ''
nextTick(() => {
//
scrollDiv.value.setScrollTop(Number.MAX_SAFE_INTEGER)
})
applicationApi.postChatMessage(chartOpenId.value, problem_text).then((response) => {
const row = chatList.value.find((item) => item.id === id)
if (row) {
ChatManagement.addChatRecord(row, 50, loading)
ChatManagement.write(id)
//
applicationApi
.postChatMessage(chartOpenId.value, problem_text)
.then((response) => {
ChatManagement.addChatRecord(chat, 50, loading)
ChatManagement.write(chat.id)
const reader = response.body.getReader()
let tempResult = ''
/*eslint no-constant-condition: ["error", { "checkLoops": false }]*/
const write = ({ done, value }: { done: boolean; value: any }) => {
try {
if (done) {
ChatManagement.close(id)
return
}
const decoder = new TextDecoder('utf-8')
let str = decoder.decode(value, { stream: true })
// start chunk chunkdata:{xxx}\n\n data:{ -> xxx}\n\n fetchchunkdata: \n\n
tempResult += str
if (tempResult.endsWith('\n\n')) {
str = tempResult
tempResult = ''
} else {
return reader.read().then(write)
}
// end
if (str && str.startsWith('data:')) {
const split = str.match(/data:.*}\n\n/g)
if (split) {
for (const index in split) {
const chunk = JSON?.parse(split[index].replace('data:', ''))
row.record_id = chunk.id
const content = chunk?.content
if (content) {
ChatManagement.append(id, content)
}
if (chunk.is_end) {
//
return Promise.resolve()
}
}
}
}
} catch (e) {
console.log(e)
// console
}
return reader.read().then(write)
}
reader
.read()
.then(write)
.then((ok: any) => {
getSourceDetail(row)
})
.finally((ok: any) => {
ChatManagement.close(id)
})
.catch((e: any) => {
ChatManagement.close(id)
})
}
})
//
const write = getWrite(
chat,
reader,
response.headers.get('Content-Type') !== 'application/json'
)
return reader.read().then(write)
})
.then(() => {
return getSourceDetail(chat)
})
.finally(() => {
ChatManagement.close(chat.id)
})
.catch((e: any) => {
MsgError(e)
ChatManagement.close(chat.id)
})
}
}
@ -382,12 +422,13 @@ function regenerationChart(item: chatType) {
chatMessage()
}
function getSourceDetail(row: chatType) {
logApi.getRecordDetail(id, row.id, row.record_id, loading).then((res) => {
const obj = { row, ...res.data }
const index = chatList.value.findIndex((v) => v.id === row.id)
chatList.value.splice(index, 1, obj)
function getSourceDetail(row: any) {
logApi.getRecordDetail(id, chartOpenId.value, row.record_id, loading).then((res) => {
Object.keys(res.data).forEach((key) => {
row[key] = res.data[key]
})
})
return true
}
/**