mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
parent
3c043944e0
commit
8fbcbf079a
|
|
@ -13,6 +13,7 @@ from django.core import validators
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.common import flat_map
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
|
|
@ -43,6 +44,13 @@ class SearchDatasetStepNodeSerializer(serializers.Serializer):
|
|||
super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
def get_paragraph_list(chat_record, node_id):
|
||||
return flat_map([chat_record.details[key].get('paragraph_list', []) for key in chat_record.details if
|
||||
(chat_record.details[
|
||||
key].get('type', '') == 'search-dataset-node') and chat_record.details[key].get(
|
||||
'paragraph_list', []) is not None and key == node_id])
|
||||
|
||||
|
||||
class ISearchDatasetStepNode(INode):
|
||||
type = 'search-dataset-node'
|
||||
|
||||
|
|
@ -53,7 +61,13 @@ class ISearchDatasetStepNode(INode):
|
|||
question = self.workflow_manage.get_reference_field(
|
||||
self.node_params_serializer.data.get('question_reference_address')[0],
|
||||
self.node_params_serializer.data.get('question_reference_address')[1:])
|
||||
return self.execute(**self.node_params_serializer.data, question=str(question), exclude_paragraph_id_list=[])
|
||||
history_chat_record = self.flow_params_serializer.data.get('history_chat_record', [])
|
||||
paragraph_id_list = [p.get('id') for p in flat_map(
|
||||
[get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if
|
||||
chat_record.problem_text == question])]
|
||||
exclude_paragraph_id_list = list(set(paragraph_id_list))
|
||||
return self.execute(**self.node_params_serializer.data, question=str(question),
|
||||
exclude_paragraph_id_list=exclude_paragraph_id_list)
|
||||
|
||||
def execute(self, dataset_id_list, dataset_setting, question,
|
||||
exclude_paragraph_id_list=None,
|
||||
|
|
|
|||
|
|
@ -582,6 +582,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
'dataset_id_list': dataset_id_list}
|
||||
|
||||
def get_search_node(self, work_flow):
|
||||
if work_flow is None:
|
||||
return []
|
||||
return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-dataset-node']
|
||||
|
||||
def update_search_node(self, work_flow, user_dataset_id_list: List):
|
||||
|
|
|
|||
|
|
@ -96,9 +96,9 @@ class PGVector(BaseVectorStore):
|
|||
return []
|
||||
query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active)
|
||||
if exclude_document_id_list is not None and len(exclude_document_id_list) > 0:
|
||||
exclude_dict.__setitem__('document_id__in', exclude_document_id_list)
|
||||
query_set = query_set.exclude(document_id__in=exclude_document_id_list)
|
||||
if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0:
|
||||
exclude_dict.__setitem__('paragraph_id__in', exclude_paragraph_list)
|
||||
query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list)
|
||||
query_set = query_set.exclude(**exclude_dict)
|
||||
for search_handle in search_handle_list:
|
||||
if search_handle.support(search_mode):
|
||||
|
|
|
|||
Loading…
Reference in New Issue