mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持工作流ai对话节点添加节点上下文 (#1791)
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
(cherry picked from commit f65546a619)
This commit is contained in:
parent
960132af46
commit
882d577450
|
|
@ -26,6 +26,8 @@ class ChatNodeSerializer(serializers.Serializer):
|
|||
|
||||
model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置"))
|
||||
|
||||
dialogue_type = serializers.CharField(required=True, error_messages=ErrMessage.char("上下文类型"))
|
||||
|
||||
|
||||
class IChatNode(INode):
|
||||
type = 'ai-chat-node'
|
||||
|
|
@ -39,5 +41,6 @@ class IChatNode(INode):
|
|||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
|
||||
chat_record_id,
|
||||
model_params_setting=None,
|
||||
dialogue_type=None,
|
||||
**kwargs) -> NodeResult:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from typing import List, Dict
|
|||
|
||||
from django.db.models import QuerySet
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import BaseMessage, AIMessage
|
||||
|
||||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.ai_chat_step_node.i_chat_node import IChatNode
|
||||
|
|
@ -72,6 +72,22 @@ def get_default_model_params_setting(model_id):
|
|||
return model_params_setting
|
||||
|
||||
|
||||
def get_node_message(chat_record, runtime_node_id):
|
||||
node_details = chat_record.get_node_details_runtime_node_id(runtime_node_id)
|
||||
if node_details is None:
|
||||
return []
|
||||
return [HumanMessage(node_details.get('question')), AIMessage(node_details.get('answer'))]
|
||||
|
||||
|
||||
def get_workflow_message(chat_record):
|
||||
return [chat_record.get_human_message(), chat_record.get_ai_message()]
|
||||
|
||||
|
||||
def get_message(chat_record, dialogue_type, runtime_node_id):
|
||||
return get_node_message(chat_record, runtime_node_id) if dialogue_type == 'NODE' else get_workflow_message(
|
||||
chat_record)
|
||||
|
||||
|
||||
class BaseChatNode(IChatNode):
|
||||
def save_context(self, details, workflow_manage):
|
||||
self.context['answer'] = details.get('answer')
|
||||
|
|
@ -80,12 +96,17 @@ class BaseChatNode(IChatNode):
|
|||
|
||||
def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
|
||||
model_params_setting=None,
|
||||
dialogue_type=None,
|
||||
**kwargs) -> NodeResult:
|
||||
if dialogue_type is None:
|
||||
dialogue_type = 'WORKFLOW'
|
||||
|
||||
if model_params_setting is None:
|
||||
model_params_setting = get_default_model_params_setting(model_id)
|
||||
chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
|
||||
**model_params_setting)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number)
|
||||
history_message = self.get_history_message(history_chat_record, dialogue_number, dialogue_type,
|
||||
self.runtime_node_id)
|
||||
self.context['history_message'] = history_message
|
||||
question = self.generate_prompt_question(prompt)
|
||||
self.context['question'] = question.content
|
||||
|
|
@ -103,10 +124,10 @@ class BaseChatNode(IChatNode):
|
|||
_write_context=write_context)
|
||||
|
||||
@staticmethod
|
||||
def get_history_message(history_chat_record, dialogue_number):
|
||||
def get_history_message(history_chat_record, dialogue_number, dialogue_type, runtime_node_id):
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
history_message = reduce(lambda x, y: [*x, *y], [
|
||||
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
get_message(history_chat_record[index], dialogue_type, runtime_node_id)
|
||||
for index in
|
||||
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
|
||||
return history_message
|
||||
|
|
|
|||
|
|
@ -167,5 +167,8 @@ class ChatRecord(AppModelMixin):
|
|||
def get_ai_message(self):
|
||||
return AIMessage(content=self.answer_text)
|
||||
|
||||
def get_node_details_runtime_node_id(self, runtime_node_id):
|
||||
return self.details.get(runtime_node_id, None)
|
||||
|
||||
class Meta:
|
||||
db_table = "application_chat_record"
|
||||
|
|
|
|||
|
|
@ -93,9 +93,8 @@
|
|||
v-if="showAnchor"
|
||||
@mousemove.stop
|
||||
@mousedown.stop
|
||||
@keydown.stop
|
||||
@click.stop
|
||||
@wheel.stop
|
||||
@wheel="handleWheel"
|
||||
:show="showAnchor"
|
||||
:id="id"
|
||||
style="left: 100%; top: 50%; transform: translate(0, -50%)"
|
||||
|
|
@ -142,6 +141,12 @@ const showNode = computed({
|
|||
return true
|
||||
}
|
||||
})
|
||||
const handleWheel = (event: any) => {
|
||||
const isCombinationKeyPressed = event.ctrlKey || event.metaKey
|
||||
if (!isCombinationKeyPressed) {
|
||||
event.stopPropagation()
|
||||
}
|
||||
}
|
||||
const node_status = computed(() => {
|
||||
if (props.nodeModel.properties.status) {
|
||||
return props.nodeModel.properties.status
|
||||
|
|
|
|||
|
|
@ -148,6 +148,15 @@
|
|||
/>
|
||||
</el-form-item>
|
||||
<el-form-item label="历史聊天记录">
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<div>历史聊天记录</div>
|
||||
<el-select v-model="chat_data.dialogue_type" type="small" style="width: 100px">
|
||||
<el-option label="节点" value="NODE" />
|
||||
<el-option label="工作流" value="WORKFLOW" />
|
||||
</el-select>
|
||||
</div>
|
||||
</template>
|
||||
<el-input-number
|
||||
v-model="chat_data.dialogue_number"
|
||||
:min="0"
|
||||
|
|
@ -246,7 +255,8 @@ const form = {
|
|||
dialogue_number: 1,
|
||||
is_result: false,
|
||||
temperature: null,
|
||||
max_tokens: null
|
||||
max_tokens: null,
|
||||
dialogue_type: 'WORKFLOW'
|
||||
}
|
||||
|
||||
const chat_data = computed({
|
||||
|
|
|
|||
Loading…
Reference in New Issue