feat: 工作流
Some checks are pending
sync2gitee / repo-sync (push) Waiting to run
Typos Check / Spell Check with Typos (push) Waiting to run

This commit is contained in:
shaohuzhang1 2024-06-17 13:52:12 +08:00
parent 029b39ae80
commit 88c97588db
6 changed files with 20 additions and 15 deletions

View File

@ -103,9 +103,9 @@ class FlowParamsSerializer(serializers.Serializer):
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.base("流式输出"))
client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id"))
client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型"))
class INode:

View File

@ -127,13 +127,13 @@ class BaseChatNode(IChatNode):
if stream:
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question}, {},
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question}, {},
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context, _to_response=to_response)
@staticmethod

View File

@ -126,13 +126,13 @@ class BaseQuestionNode(IQuestionNode):
r = chat_model.stream(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'get_to_response_write_context': get_to_response_write_context,
'history_message': history_message, 'question': question}, {},
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context_stream,
_to_response=to_stream_response)
else:
r = chat_model.invoke(message_list)
return NodeResult({'result': r, 'chat_model': chat_model, 'message_list': message_list,
'history_message': history_message, 'question': question}, {},
'history_message': history_message, 'question': question.content}, {},
_write_context=write_context, _to_response=to_response)
@staticmethod

View File

@ -143,13 +143,18 @@ class WorkflowManage:
@param prompt: 提示词信息
@return: 格式化后的提示词
"""
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
context = {
'global': self.context,
}
for node in self.node_context:
fields = node.node.properties.get('fields')
if fields is not None:
for field in fields:
prompt = prompt.replace(field.get('globeLabel'), field.get('globeValue'))
context[node.id] = node.context
prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2')
value = prompt_template.format(context=context)
return value

View File

@ -231,7 +231,7 @@ class ChatMessageSerializer(serializers.Serializer):
stream = self.data.get('stream')
client_id = self.data.get('client_id')
client_type = self.data.get('client_type')
work_flow_manage = WorkflowManage(Flow.new_instance(json.loads(chat_info.application.work_flow)),
work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.application.work_flow),
{'history_chat_record': chat_info.chat_record_list, 'question': message,
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
@ -241,7 +241,7 @@ class ChatMessageSerializer(serializers.Serializer):
def chat(self):
super().is_valid(raise_exception=True)
application = QuerySet(Application).filter(self.data.get('application_id'))
application = QuerySet(Application).filter(id=self.data.get('application_id')).first()
if application.type == ApplicationTypeChoices.SIMPLE:
chat_info = self.is_valid_application_simple(raise_exception=True)
return self.chat_simple(chat_info)

View File

@ -38,12 +38,12 @@ class AppNode extends HtmlNode {
if (filterNodes.length - 1 > 0) {
props.model.properties.stepName = props.model.properties.stepName + (filterNodes.length - 1)
}
if (props.model.properties?.fields?.length > 0) {
props.model.properties.fields.map((item: any) => {
item['globeLabel'] = `{{${props.model.properties.stepName}.${item.value}}}`
item['globeValue'] = `{{context['${props.model.id}'].${item.value}}}`
})
}
}
if (props.model.properties?.fields?.length > 0) {
props.model.properties.fields.map((item: any) => {
item['globeLabel'] = `{{${props.model.properties.stepName}.${item.value}}}`
item['globeValue'] = `{{context['${props.model.id}'].${item.value}}}`
})
}
}