feat: 工作流

This commit is contained in:
shaohuzhang1 2024-06-20 15:23:07 +08:00
parent f36b494c6a
commit 33d02ed146
3 changed files with 6 additions and 105 deletions

View File

@ -7,6 +7,7 @@
@desc:
"""
import json
from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
@ -139,16 +140,17 @@ class BaseChatNode(IChatNode):
@staticmethod
def get_history_message(history_chat_record, dialogue_number):
start_index = len(history_chat_record) - dialogue_number
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))]
history_message = reduce(lambda x, y: [*x, *y], [
[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
range(start_index if start_index > 0 else 0, len(history_chat_record))], [])
return history_message
def generate_prompt_question(self, prompt):
return HumanMessage(self.workflow_manage.generate_prompt(prompt))
def generate_message_list(self, system: str, prompt: str, history_message):
if system is None or len(system) == 0:
if system is not None and len(system) > 0:
return [SystemMessage(self.workflow_manage.generate_prompt(system)), *history_message,
HumanMessage(self.workflow_manage.generate_prompt(prompt))]
else:

View File

@ -31,51 +31,6 @@ class ConditionNodeParamsSerializer(serializers.Serializer):
branch = ConditionBranchSerializer(many=True)
j = """
{ "branch": [
{
"conditions": [
{
"field": [
"34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
"paragraph_list"
],
"compare": "len_eq",
"value": "0"
}
],
"id": "2391",
"condition": "and"
},
{
"conditions": [
{
"field": [
"34902d3d-a3ff-497f-b8e1-0c34a44d7dd5",
"paragraph_list"
],
"compare": "len_eq",
"value": "1"
}
],
"id": "1143",
"condition": "and"
},
{
"conditions": [
],
"id": "9208",
"condition": "and"
}
]}
"""
a = json.loads(j)
c = ConditionNodeParamsSerializer(data=a)
c.is_valid(raise_exception=True)
print(c.data)
class IConditionNode(INode):
def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
return ConditionNodeParamsSerializer

View File

@ -30,59 +30,3 @@ class QianfanChatModel(QianfanChatEndpoint):
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
def stream(
self,
input: LanguageModelInput,
config: Optional[RunnableConfig] = None,
*,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
if len(input) % 2 == 0:
input = [HumanMessage(content='padding'), *input]
input = [
HumanMessage(content=input[index].content) if index % 2 == 0 else AIMessage(content=input[index].content)
for index in range(0, len(input))]
if type(self)._stream == BaseChatModel._stream:
# model doesn't implement streaming, so use default implementation
yield cast(
BaseMessageChunk, self.invoke(input, config=config, stop=stop, **kwargs)
)
else:
config = config or {}
messages = self._convert_input(input).to_messages()
params = self._get_invocation_params(stop=stop, **kwargs)
options = {"stop": stop, **kwargs}
callback_manager = CallbackManager.configure(
config.get("callbacks"),
self.callbacks,
self.verbose,
config.get("tags"),
self.tags,
config.get("metadata"),
self.metadata,
)
(run_manager,) = callback_manager.on_chat_model_start(
dumpd(self),
[messages],
invocation_params=params,
options=options,
name=config.get("run_name"),
)
try:
generation: Optional[ChatGenerationChunk] = None
for chunk in self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
):
yield chunk.message
if generation is None:
generation = chunk
assert generation is not None
except BaseException as e:
run_manager.on_llm_error(e)
raise e
else:
run_manager.on_llm_end(
LLMResult(generations=[[generation]]),
)