diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 496b1847a..9eab397f9 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -693,33 +693,43 @@ class WorkflowManage: else: return self.get_node_by_id(node_id).get_reference_field(fields) - def generate_prompt(self, prompt: str): - """ - 格式化生成提示词 - @param prompt: 提示词信息 - @return: 格式化后的提示词 - """ + def get_workflow_content(self): context = { 'global': self.context, } for node in self.node_context: - properties = node.node.properties + context[node.id] = node.context + return context + + def reset_prompt(self, prompt: str): + placeholder = "{}" + for node in self.flow.nodes: + properties = node.properties node_config = properties.get('config') if node_config is not None: fields = node_config.get('fields') if fields is not None: for field in fields: globeLabel = f"{properties.get('stepName')}.{field.get('value')}" - globeValue = f"context['{node.id}'].{field.get('value')}" + globeValue = f"context.get('{node.id}',{placeholder}).get('{field.get('value', '')}','')" prompt = prompt.replace(globeLabel, globeValue) global_fields = node_config.get('globalFields') if global_fields is not None: for field in global_fields: globeLabel = f"全局变量.{field.get('value')}" - globeValue = f"context['global'].{field.get('value')}" + globeValue = f"context.get('global').get('{field.get('value', '')}','')" prompt = prompt.replace(globeLabel, globeValue) - context[node.id] = node.context + return prompt + + def generate_prompt(self, prompt: str): + """ + 格式化生成提示词 + @param prompt: 提示词信息 + @return: 格式化后的提示词 + """ + context = self.get_workflow_content() + prompt = self.reset_prompt(prompt) prompt_template = PromptTemplate.from_template(prompt, template_format='jinja2') value = prompt_template.format(context=context) return value