mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
feat: 支持重排模型
This commit is contained in:
parent
fcbfd8a07c
commit
3faba75d4d
|
|
@ -14,9 +14,10 @@ from .start_node import *
|
|||
from .direct_reply_node import *
|
||||
from .function_lib_node import *
|
||||
from .function_node import *
|
||||
from .reranker_node import *
|
||||
|
||||
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode]
|
||||
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]
|
||||
|
||||
|
||||
def get_node(node_type):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/9/4 11:37
|
||||
@desc:
|
||||
"""
|
||||
from .impl import *
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_reranker_node.py
|
||||
@date:2024/9/4 10:40
|
||||
@desc:
|
||||
"""
|
||||
from typing import Type
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from application.flow.i_step_node import INode, NodeResult
|
||||
from common.util.field_message import ErrMessage
|
||||
|
||||
|
||||
class RerankerSettingSerializer(serializers.Serializer):
    """Validated settings that control how reranked paragraphs are filtered.

    Consumed by BaseRerankerNode.execute via filter_result.
    """
    # Number of top-ranked paragraphs to keep after reranking.
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Minimum relevance score. Rerankers emit scores in (0, 1); the max of 2
    # mirrors the blended-search similarity range used elsewhere in the project.
    # Fixed: error message was copy-pasted from top_n ("引用分段数").
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # Hard cap on the total number of characters quoted from reranked paragraphs.
    # Fixed: this is an IntegerField, so use ErrMessage.integer (was .float).
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         error_messages=ErrMessage.integer("最大引用分段字数"))
|
||||
|
||||
|
||||
class RerankerStepNodeSerializer(serializers.Serializer):
    """Node-level parameters for the reranker workflow step."""
    # Filtering settings (top_n / similarity / max paragraph char budget).
    reranker_setting = RerankerSettingSerializer(required=True)

    # Reference path ([node_id, field, ...]) pointing at the question text.
    question_reference_address = serializers.ListField(required=True)
    # UUID of the reranker model to apply.
    reranker_model_id = serializers.UUIDField(required=True)
    # One reference path per upstream retrieval result to merge and rerank.
    reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))

    def is_valid(self, *, raise_exception=False):
        # Always raise on invalid data regardless of the caller's flag,
        # matching the convention of the other step-node serializers.
        super().is_valid(raise_exception=True)
|
||||
|
||||
|
||||
class IRerankerNode(INode):
    """Abstract workflow node that reranks merged retrieval results.

    Resolves the question and the upstream retrieval results from workflow
    references, then delegates the actual reranking to execute().
    """
    type = 'reranker-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        return RerankerStepNodeSerializer

    def _run(self):
        # Resolve the question text from its reference address
        # ([node_id, *field_path]) via the workflow context.
        question = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('question_reference_address')[0],
            self.node_params_serializer.data.get('question_reference_address')[1:])
        # Resolve each upstream retrieval result the same way; entries may be
        # lists/dicts/strings — BaseRerankerNode flattens them before scoring.
        reranker_list = [self.workflow_manage.get_reference_field(
            reference[0],
            reference[1:]) for reference in
            self.node_params_serializer.data.get('reranker_reference_list')]
        return self.execute(**self.node_params_serializer.data, question=str(question),
                            reranker_list=reranker_list)

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        # Implemented by BaseRerankerNode.
        pass
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/9/4 11:39
|
||||
@desc:
|
||||
"""
|
||||
from .base_reranker_node import *
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: base_reranker_node.py
|
||||
@date:2024/9/4 11:41
|
||||
@desc:
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from application.flow.i_step_node import NodeResult
|
||||
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
|
||||
from setting.models_provider.tools import get_model_instance_by_model_user_id
|
||||
|
||||
|
||||
def merge_reranker_list(reranker_list, result=None):
    """Flatten a nested list of retrieval results into a flat list of strings.

    Nested lists are flattened recursively. Dicts are rendered as their
    'title' + 'content' text, falling back to the dict's repr when both are
    empty. Anything else is stringified. When *result* is supplied, items are
    appended to it in place and the same list is returned.
    """
    flattened = [] if result is None else result
    for item in reranker_list:
        if isinstance(item, list):
            merge_reranker_list(item, flattened)
        elif isinstance(item, dict):
            text = item.get('title', '') + item.get('content', '')
            flattened.append(text if text else str(item))
        else:
            flattened.append(str(item))
    return flattened
|
||||
|
||||
|
||||
def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
    """Pick the leading documents that fit within all filter limits.

    *document_list* is assumed to be sorted by relevance (highest first).
    Iteration stops at the first document that would blow the character
    budget, exceed the *top_n* count, or score below the *similarity*
    threshold; the last accepted document is truncated to the remaining
    budget so the combined length never exceeds max_paragraph_char_number.
    Returns a list of {'page_content', 'metadata'} dicts.
    """
    kept = []
    budget_used = 0
    for position, doc in enumerate(document_list):
        if (budget_used >= max_paragraph_char_number or position >= top_n
                or doc.metadata.get('relevance_score') < similarity):
            break
        snippet = doc.page_content[0:max_paragraph_char_number - budget_used]
        budget_used += len(snippet)
        kept.append({'page_content': snippet, 'metadata': doc.metadata})
    return kept
|
||||
|
||||
|
||||
class BaseRerankerNode(IRerankerNode):
    """Concrete reranker node: flattens upstream retrieval results, scores
    them against the question with a reranker model, and returns the
    filtered top paragraphs."""

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        # Flatten nested retrieval results into plain text snippets.
        documents = merge_reranker_list(reranker_list)
        # Resolve the reranker model on behalf of the workflow's owning user.
        reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
                                                             self.flow_params_serializer.data.get('user_id'))
        # Score every non-empty snippet against the question; the model
        # returns documents sorted by relevance_score (highest first).
        result = reranker_model.compress_documents(
            [Document(page_content=document) for document in documents if document is not None and len(document) > 0],
            question)
        top_n = reranker_setting.get('top_n', 3)
        similarity = reranker_setting.get('similarity', 0.6)
        max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
        # Keep only the documents within the top_n / similarity / char-budget limits.
        r = filter_result(result, max_paragraph_char_number, top_n, similarity)
        return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})

    def get_details(self, index: int, **kwargs):
        """Build the details payload shown in the execution-trace UI."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            # NOTE(review): 'question' is not a declared field on
            # RerankerStepNodeSerializer, so this lookup likely yields None —
            # confirm against what the UI expects.
            "question": self.node_params_serializer.data.get('question'),
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
            'result_list': self.context.get('result_list'),
            'result': self.context.get('result'),
            'status': self.status,
            'err_message': self.err_message
        }
|
||||
|
|
@ -141,6 +141,7 @@ class ModelTypeConst(Enum):
|
|||
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
|
||||
STT = {'code': 'STT', 'message': '语音识别'}
|
||||
TTS = {'code': 'TTS', 'message': '语音合成'}
|
||||
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
|
||||
|
||||
|
||||
class ModelInfo:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py
|
||||
@date:2024/9/3 14:33
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
|
||||
|
||||
|
||||
class LocalRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form/validator for locally hosted reranker models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential dict and smoke-test the model with a tiny rerank call.

        Returns True on success. On failure either returns False or raises
        AppApiException, depending on raise_exception — except for an
        unsupported model_type, which always raises.
        """
        if not model_type == 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        # Required credential keys.
        for key in ['cache_dir']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Instantiate the model and run a minimal rerank to prove it loads
            # and scores. This may download weights into cache_dir.
            model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content='你好')], '你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Local models hold no secrets, so nothing needs masking.
        return model

    # Filesystem directory where model weights are cached/downloaded.
    cache_dir = forms.TextInputField('模型目录', required=True)
|
||||
|
|
@ -16,14 +16,20 @@ from common.util.file_util import get_file_content
|
|||
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
|
||||
ModelInfoManage
|
||||
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
|
||||
from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
|
||||
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||
from smartdoc.conf import PROJECT_DIR
|
||||
|
||||
# Built-in local models: a Chinese text2vec embedding model and the
# BGE reranker v2-m3 cross-encoder.
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
                                            LocalEmbeddingCredential(), LocalEmbedding)
bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
                               LocalRerankerCredential(), LocalReranker)

# Register both models and mark each as the default for its model type.
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
                     .append_default_model_info(embedding_text2vec_base_chinese)
                     .append_model_info(bge_reranker_v2_m3)
                     .append_default_model_info(bge_reranker_v2_m3)
                     .build())
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,97 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py.py
|
||||
@date:2024/9/2 16:42
|
||||
@desc:
|
||||
"""
|
||||
from typing import Sequence, Optional, Dict, Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from smartdoc.const import CONFIG
|
||||
|
||||
|
||||
class LocalReranker(MaxKBBaseModel):
    """Factory wrapper that selects an in-process or HTTP-backed reranker."""

    def __init__(self, model_name, top_n=3, cache_dir=None):
        super().__init__()
        # HuggingFace model name or local path.
        self.model_name = model_name
        # Optional HuggingFace cache directory.
        self.cache_dir = cache_dir
        # Default number of documents to keep (not read by new_instance).
        self.top_n = top_n

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # In-process model by default; a web client when use_local is False
        # (i.e. the model runs in the separate local_model service).
        # NOTE(review): model_kwargs={'device': ...} lands inside
        # LocalBaseReranker's **model_kwargs as a nested dict, so its
        # .get('device') falls back to 'cpu' — confirm whether GPU selection
        # is expected to work through this path.
        if model_kwargs.get('use_local', True):
            return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
                                     model_kwargs={'device': model_credential.get('device', 'cpu')}
                                     )
        return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
                                    model_kwargs={'device': model_credential.get('device')},
                                    **model_kwargs)
|
||||
|
||||
|
||||
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """HTTP client that delegates reranking to the local_model web service."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Instances are created directly by LocalReranker.new_instance;
        # this factory hook is intentionally left unimplemented.
        pass

    # UUID of the server-side model to call.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """POST the documents and query to the local model service and rebuild
        langchain Documents (with relevance scores in metadata) from the JSON
        reply. Raises on a non-200 application code."""
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
            json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
                                documents], 'query': query}, headers={'Content-Type': 'application/json'})
        result = res.json()
        # The service wraps responses as {'code': ..., 'data': ..., 'msg': ...}.
        if result.get('code', 500) == 200:
            return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
                    in result.get('data')]
        raise Exception(result.get('msg'))
|
||||
|
||||
|
||||
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """In-process cross-encoder reranker built on HuggingFace transformers."""
    # Loaded AutoModelForSequenceClassification instance.
    client: Any = None
    # Tokenizer matching the model.
    tokenizer: Any = None
    # HuggingFace model name or local path.
    model: Optional[str] = None
    # Optional HuggingFace cache directory.
    cache_dir: Optional[str] = None
    # Extra options such as {'device': 'cpu'}. NOTE(review): class-level
    # mutable default — shared until __init__ reassigns it per instance.
    model_kwargs = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        # Loading may download weights into cache_dir on first use.
        self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        # Move to the requested device and switch to inference mode.
        self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
        self.client.eval()

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Score each document against *query* and return copies sorted by
        descending relevance_score (stored in metadata)."""
        with torch.no_grad():
            # Cross-encoder scoring: tokenize (query, passage) pairs jointly.
            inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
                                    truncation=True, return_tensors='pt', max_length=512)
            # Sigmoid maps raw logits to (0, 1) relevance scores.
            scores = [torch.sigmoid(s).float().item() for s in
                      self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
            result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
                      for index
                      in range(len(documents))]
            # Highest relevance first.
            result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
            return result
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
@desc:
|
||||
"""
|
||||
from django.db.models import QuerySet
|
||||
from langchain_core.documents import Document
|
||||
from rest_framework import serializers
|
||||
|
||||
from common.config.embedding_config import ModelManage
|
||||
|
|
@ -33,6 +34,16 @@ class EmbedQuery(serializers.Serializer):
|
|||
text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))
|
||||
|
||||
|
||||
class CompressDocument(serializers.Serializer):
    """One document submitted for reranking."""
    page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本"))
    metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据"))


class CompressDocuments(serializers.Serializer):
    """Request body for the compress_documents (rerank) endpoint."""
    # List of documents to score against the query.
    documents = CompressDocument(required=True, many=True)
    query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query"))
|
||||
|
||||
|
||||
class ModelApplySerializers(serializers.Serializer):
|
||||
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
|
||||
|
|
@ -51,3 +62,12 @@ class ModelApplySerializers(serializers.Serializer):
|
|||
|
||||
model = get_embedding_model(self.data.get('model_id'))
|
||||
return model.embed_query(instance.get('text'))
|
||||
|
||||
def compress_documents(self, instance, with_valid=True):
    """Rerank instance['documents'] against instance['query'] and return
    plain {'page_content', 'metadata'} dicts suitable for JSON responses.

    NOTE(review): the model is fetched via get_embedding_model — presumably
    that helper returns any cached local model (rerankers included); confirm
    it resolves reranker model ids correctly.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        # Validate the request payload shape before touching the model.
        CompressDocuments(data=instance).is_valid(raise_exception=True)
    model = get_embedding_model(self.data.get('model_id'))
    return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
        [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
         instance.get('documents')], instance.get('query'))]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ urlpatterns = [
|
|||
path('provider/model_form', views.Provide.ModelForm.as_view(),
|
||||
name="provider/model_form"),
|
||||
path('model', views.Model.as_view(), name='model'),
|
||||
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
|
||||
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
|
||||
name='model/model_params_form'),
|
||||
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
|
||||
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
|
||||
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
|
||||
|
|
@ -31,4 +32,6 @@ if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
|||
name='model/embed_documents'),
|
||||
path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
|
||||
name='model/embed_query'),
|
||||
# Rerank endpoint, exposed only by the local_model service (see the
# SERVER_NAME guard above). Fixed: route name was copy-pasted as
# 'model/embed_query', duplicating the embed_query route and breaking reverse().
path('model/<str:model_id>/compress_documents', views.ModelApply.CompressDocuments.as_view(),
     name='model/compress_documents'),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -36,3 +36,13 @@ class ModelApply(APIView):
|
|||
def post(self, request: Request, model_id):
|
||||
return result.success(
|
||||
ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
|
||||
|
||||
class CompressDocuments(APIView):
    """Rerank a batch of documents against a query using the given model."""

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="重排序文档",
                         operation_id="重排序文档",
                         responses=result.get_default_response(),
                         tags=["模型"])
    def post(self, request: Request, model_id):
        # Body: {'documents': [...], 'query': str} — validated by the serializer.
        return result.success(
            ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||
|
|
|
|||
|
|
@ -237,6 +237,19 @@ const getApplicationModel: (
|
|||
return get(`${prefix}/${application_id}/model`, loading)
|
||||
}
|
||||
|
||||
/**
 * Fetch the reranker (RERANKER type) models available to the application.
 * @param application_id application id
 * @param loading optional loading-state ref
 * @returns list of reranker models usable by the current user
 */
const getApplicationRerankerModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading)
}
|
||||
/**
|
||||
* 发布应用
|
||||
* @param 参数
|
||||
|
|
@ -310,5 +323,6 @@ export default {
|
|||
postWorkflowChatOpen,
|
||||
listFunctionLib,
|
||||
getFunctionLib,
|
||||
getModelParamsForm
|
||||
getModelParamsForm,
|
||||
getApplicationRerankerModel
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,16 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none">
|
||||
<g clip-path="url(#clip0_6564_362280)">
|
||||
<path d="M0 4C0 1.79086 1.95367 0 4.36364 0H19.6364C22.0463 0 24 1.79086 24 4V20C24 22.2091 22.0463 24 19.6364 24H4.36364C1.95367 24 0 22.2091 0 20V4Z" fill="none"/>
|
||||
<path d="M8.00004 7.66667C8.00004 6.9303 7.40307 6.33333 6.66671 6.33333C5.93033 6.33333 5.33337 6.9303 5.33337 7.66667C5.33337 8.40303 5.93033 9 6.66671 9C7.40307 9 8.00004 8.40303 8.00004 7.66667Z" fill="white"/>
|
||||
<path d="M8.00004 16.3333C8.00004 15.5969 7.40307 15 6.66671 15C5.93033 15 5.33337 15.5969 5.33337 16.3333C5.33337 17.0697 5.93033 17.6666 6.66671 17.6666C7.40307 17.6666 8.00004 17.0697 8.00004 16.3333Z" fill="white"/>
|
||||
<path d="M19.3334 12C19.3334 10.8955 18.4379 10 17.3334 10C16.2288 10 15.3334 10.8955 15.3334 12C15.3334 13.1046 16.2288 14 17.3334 14C18.4379 14 19.3334 13.1046 19.3334 12Z" fill="white"/>
|
||||
<path d="M8.66663 7.33337C8.66663 6.22882 7.77118 5.33337 6.66663 5.33337C5.56208 5.33337 4.66663 6.22882 4.66663 7.33337C4.66663 8.43792 5.56208 9.33337 6.66663 9.33337C7.77118 9.33337 8.66663 8.43792 8.66663 7.33337Z" fill="white"/>
|
||||
<path d="M8.66663 16.6666C8.66663 15.5621 7.77118 14.6666 6.66663 14.6666C5.56208 14.6666 4.66663 15.5621 4.66663 16.6666C4.66663 17.7712 5.56208 18.6666 6.66663 18.6666C7.77118 18.6666 8.66663 17.7712 8.66663 16.6666Z" fill="white"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.33337 16.3333C7.33337 15.9651 7.63185 15.6667 8.00004 15.6667H9.00004C9.35848 15.6667 9.70703 15.5052 10.0995 15.1574C10.5012 14.8013 10.8785 14.3145 11.305 13.7602C11.3119 13.7511 11.3189 13.7421 11.3259 13.733C11.7275 13.2108 12.1767 12.6268 12.6828 12.1782C12.7517 12.1171 12.8229 12.0575 12.8962 12C12.8229 11.9425 12.7517 11.8829 12.6828 11.8218C12.1767 11.3732 11.7275 10.7892 11.3259 10.267L11.305 10.2398C10.8785 9.68546 10.5012 9.19868 10.0995 8.84265C9.70703 8.49477 9.35848 8.33333 9.00004 8.33333H8.00004C7.63185 8.33333 7.33337 8.03486 7.33337 7.66667C7.33337 7.29848 7.63185 7 8.00004 7H9.00004C9.80827 7 10.4597 7.38023 10.9839 7.84485C11.4901 8.2935 11.9392 8.87749 12.3409 9.39964L12.3618 9.42686C12.7882 9.9812 13.1656 10.468 13.5672 10.824C13.9597 11.1719 14.3083 11.3333 14.6667 11.3333H16C16.3682 11.3333 16.6667 11.6318 16.6667 12C16.6667 12.3682 16.3682 12.6667 16 12.6667H14.6667C14.3083 12.6667 13.9597 12.8281 13.5672 13.176C13.1656 13.532 12.7882 14.0188 12.3618 14.5731L12.3409 14.6004C11.9392 15.1225 11.4901 15.7065 10.9839 16.1551C10.4597 16.6198 9.80827 17 9.00004 17H8.00004C7.63185 17 7.33337 16.7015 7.33337 16.3333Z" fill="white"/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_6564_362280">
|
||||
<rect width="24" height="24" fill="white"/>
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 2.7 KiB |
|
|
@ -7,5 +7,6 @@ export enum WorkflowType {
|
|||
Condition = 'condition-node',
|
||||
Reply = 'reply-node',
|
||||
FunctionLib = 'function-lib-node',
|
||||
FunctionLibCustom = 'function-node'
|
||||
FunctionLibCustom = 'function-node',
|
||||
RrerankerNode = 'reranker-node'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -139,7 +139,30 @@ export const replyNode = {
|
|||
}
|
||||
}
|
||||
}
|
||||
export const menuNodes = [aiChatNode, searchDatasetNode, questionNode, conditionNode, replyNode]
|
||||
// Palette definition for the reranker ("multi-path recall") workflow node.
export const rerankerNode = {
  // NOTE(review): the enum member is spelled RrerankerNode (double R) —
  // its value is the correct 'reranker-node'.
  type: WorkflowType.RrerankerNode,
  text: '使用重排模型对多个知识库的检索结果进行二次召回',
  label: '多路召回',
  properties: {
    stepName: '多路召回',
    config: {
      // Output fields other nodes can reference from this node.
      fields: [
        {
          label: '结果',
          value: 'result'
        }
      ]
    }
  }
}
// Nodes shown in the workflow editor's add-node menu.
export const menuNodes = [
  aiChatNode,
  searchDatasetNode,
  questionNode,
  conditionNode,
  replyNode,
  rerankerNode
]
|
||||
|
||||
/**
|
||||
* 自定义函数配置数据
|
||||
|
|
@ -203,7 +226,8 @@ export const nodeDict: any = {
|
|||
[WorkflowType.Start]: startNode,
|
||||
[WorkflowType.Reply]: replyNode,
|
||||
[WorkflowType.FunctionLib]: functionLibNode,
|
||||
[WorkflowType.FunctionLibCustom]: functionNode
|
||||
[WorkflowType.FunctionLibCustom]: functionNode,
|
||||
[WorkflowType.RrerankerNode]: rerankerNode
|
||||
}
|
||||
export function isWorkFlow(type: string | undefined) {
|
||||
return type === 'WORK_FLOW'
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
<!-- Avatar for the reranker workflow node: purple square with the reranker icon. -->
<template>
  <AppAvatar shape="square" style="background: #d136d1">
    <img src="@/assets/icon_reranker.svg" style="width: 75%" alt="" />
  </AppAvatar>
</template>
<script setup lang="ts"></script>
|
||||
|
|
@ -0,0 +1,266 @@
|
|||
<template>
|
||||
<el-dialog
|
||||
align-center
|
||||
:title="$t('views.application.applicationForm.dialogues.paramSettings')"
|
||||
class="param-dialog"
|
||||
v-model="dialogVisible"
|
||||
style="width: 550px"
|
||||
append-to-body
|
||||
>
|
||||
<div>
|
||||
<el-scrollbar always>
|
||||
<div class="p-16">
|
||||
<el-form label-position="top" ref="paramFormRef" :model="form">
|
||||
<el-row :gutter="20">
|
||||
<el-col :span="12">
|
||||
<el-form-item>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<span class="mr-4">{{
|
||||
$t('views.application.applicationForm.dialogues.similarityThreshold')
|
||||
}}</span>
|
||||
<el-tooltip effect="dark" content="相似度越高相关性越强。" placement="right">
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<el-input-number
|
||||
v-model="form.similarity"
|
||||
:min="0"
|
||||
:max="form.search_mode === 'blend' ? 2 : 1"
|
||||
:precision="3"
|
||||
:step="0.1"
|
||||
:value-on-clear="0"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="12">
|
||||
<el-form-item
|
||||
:label="$t('views.application.applicationForm.dialogues.topReferences')"
|
||||
>
|
||||
<el-input-number
|
||||
v-model="form.top_n"
|
||||
:min="1"
|
||||
:max="100"
|
||||
:value-on-clear="1"
|
||||
controls-position="right"
|
||||
class="w-full"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
</el-row>
|
||||
|
||||
<el-form-item :label="$t('views.application.applicationForm.dialogues.maxCharacters')">
|
||||
<el-slider
|
||||
v-model="form.max_paragraph_char_number"
|
||||
show-input
|
||||
:show-input-controls="false"
|
||||
:min="500"
|
||||
:max="100000"
|
||||
class="custom-slider"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
v-if="!isWorkflowType"
|
||||
:label="$t('views.application.applicationForm.dialogues.noReferencesAction')"
|
||||
>
|
||||
<el-form
|
||||
label-position="top"
|
||||
ref="noReferencesformRef"
|
||||
:model="noReferencesform"
|
||||
:rules="noReferencesRules"
|
||||
class="w-full"
|
||||
:hide-required-asterisk="true"
|
||||
>
|
||||
<el-radio-group
|
||||
v-model="form.no_references_setting.status"
|
||||
class="radio-block mb-16"
|
||||
>
|
||||
<div>
|
||||
<el-radio value="ai_questioning">
|
||||
<p>
|
||||
{{ $t('views.application.applicationForm.dialogues.continueQuestioning') }}
|
||||
</p>
|
||||
<el-form-item
|
||||
v-if="form.no_references_setting.status === 'ai_questioning'"
|
||||
:label="$t('views.application.applicationForm.form.prompt.label')"
|
||||
prop="ai_questioning"
|
||||
>
|
||||
<el-input
|
||||
v-model="noReferencesform.ai_questioning"
|
||||
:rows="2"
|
||||
type="textarea"
|
||||
maxlength="2048"
|
||||
:placeholder="defaultValue['ai_questioning']"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-radio>
|
||||
</div>
|
||||
<div class="mt-8">
|
||||
<el-radio value="designated_answer">
|
||||
<p>{{ $t('views.application.applicationForm.dialogues.provideAnswer') }}</p>
|
||||
<el-form-item
|
||||
v-if="form.no_references_setting.status === 'designated_answer'"
|
||||
prop="designated_answer"
|
||||
>
|
||||
<el-input
|
||||
v-model="noReferencesform.designated_answer"
|
||||
:rows="2"
|
||||
type="textarea"
|
||||
maxlength="2048"
|
||||
:placeholder="defaultValue['designated_answer']"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-radio>
|
||||
</div>
|
||||
</el-radio-group>
|
||||
</el-form>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</div>
|
||||
</el-scrollbar>
|
||||
</div>
|
||||
<template #footer>
|
||||
<span class="dialog-footer p-16">
|
||||
<el-button @click.prevent="dialogVisible = false">{{
|
||||
$t('views.application.applicationForm.buttons.cancel')
|
||||
}}</el-button>
|
||||
<el-button type="primary" @click="submit(noReferencesformRef)" :loading="loading">
|
||||
{{ $t('views.application.applicationForm.buttons.save') }}
|
||||
</el-button>
|
||||
</span>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, reactive } from 'vue'
|
||||
import { cloneDeep } from 'lodash'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
import { isWorkFlow } from '@/utils/application'
|
||||
import { t } from '@/locales'
|
||||
const emit = defineEmits(['refresh'])
|
||||
|
||||
const paramFormRef = ref()
|
||||
const noReferencesformRef = ref()
|
||||
|
||||
const defaultValue = {
|
||||
ai_questioning: '{question}',
|
||||
// @ts-ignore
|
||||
designated_answer: t('views.application.applicationForm.dialogues.designated_answer')
|
||||
}
|
||||
|
||||
const form = ref<any>({
|
||||
search_mode: 'embedding',
|
||||
top_n: 3,
|
||||
similarity: 0.6,
|
||||
max_paragraph_char_number: 5000,
|
||||
no_references_setting: {
|
||||
status: 'ai_questioning',
|
||||
value: '{question}'
|
||||
}
|
||||
})
|
||||
|
||||
const noReferencesform = ref<any>({
|
||||
ai_questioning: defaultValue['ai_questioning'],
|
||||
designated_answer: defaultValue['designated_answer']
|
||||
})
|
||||
|
||||
const noReferencesRules = reactive<FormRules<any>>({
|
||||
ai_questioning: [
|
||||
{
|
||||
required: true,
|
||||
message: t('views.application.applicationForm.dialogues.promptPlaceholder'),
|
||||
trigger: 'blur'
|
||||
}
|
||||
],
|
||||
designated_answer: [
|
||||
{
|
||||
required: true,
|
||||
message: t('views.application.applicationForm.dialogues.concentPlaceholder'),
|
||||
trigger: 'blur'
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
const loading = ref(false)
|
||||
|
||||
const isWorkflowType = ref(false)
|
||||
|
||||
watch(dialogVisible, (bool) => {
|
||||
if (!bool) {
|
||||
form.value = {
|
||||
search_mode: 'embedding',
|
||||
top_n: 3,
|
||||
similarity: 0.6,
|
||||
max_paragraph_char_number: 5000,
|
||||
no_references_setting: {
|
||||
status: 'ai_questioning',
|
||||
value: ''
|
||||
}
|
||||
}
|
||||
noReferencesform.value = {
|
||||
ai_questioning: defaultValue['ai_questioning'],
|
||||
designated_answer: defaultValue['designated_answer']
|
||||
}
|
||||
noReferencesformRef.value?.clearValidate()
|
||||
}
|
||||
})
|
||||
|
||||
// Open the dialog pre-filled with the given settings. `type` toggles
// workflow mode, which hides the no-references section in the template.
const open = (data: any, type?: string) => {
  isWorkflowType.value = isWorkFlow(type)
  form.value = { ...form.value, ...cloneDeep(data) }
  // Seed the field matching the currently selected no-references mode.
  noReferencesform.value[form.value.no_references_setting.status] =
    form.value.no_references_setting.value
  dialogVisible.value = true
}
|
||||
|
||||
// Persist the form. Workflow nodes have no no-references behavior, so that
// key is dropped; otherwise the active no-references field is validated
// and copied back before emitting.
const submit = async (formEl: FormInstance | undefined) => {
  if (isWorkflowType.value) {
    delete form.value['no_references_setting']
    emit('refresh', form.value)
    dialogVisible.value = false
  } else {
    if (!formEl) return
    await formEl.validate((valid, fields) => {
      if (valid) {
        // Copy the validated text for the selected mode into the setting.
        form.value.no_references_setting.value =
          noReferencesform.value[form.value.no_references_setting.status]
        emit('refresh', form.value)
        dialogVisible.value = false
      }
    })
  }
}
|
||||
|
||||
// Reset the similarity threshold when the search mode changes: keyword
// search has no vector similarity (0); other modes restore the 0.6 default.
// NOTE(review): no caller is visible in this component's template — confirm
// whether this handler is still wired up or is dead code.
function changeHandle(val: string) {
  if (val === 'keywords') {
    form.value.similarity = 0
  } else {
    form.value.similarity = 0.6
  }
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss" scope>
|
||||
.param-dialog {
|
||||
padding: 8px 8px 24px 8px;
|
||||
.el-dialog__header {
|
||||
padding: 16px 16px 0 16px;
|
||||
}
|
||||
.el-dialog__body {
|
||||
padding: 0 !important;
|
||||
}
|
||||
.dialog-max-height {
|
||||
height: 550px;
|
||||
}
|
||||
.custom-slider {
|
||||
.el-input-number.is-without-controls .el-input__wrapper {
|
||||
padding: 0 !important;
|
||||
}
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
import RerankerNodeVue from './index.vue'
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'

// LogicFlow node class binding the reranker Vue component to the shared
// app-node shell.
class RerankerNode extends AppNode {
  constructor(props: any) {
    super(props, RerankerNodeVue)
  }
}
// Node registration entry; `type` must match the backend node type
// ('reranker-node', i.e. WorkflowType.RrerankerNode's value).
export default {
  type: 'reranker-node',
  model: AppNodeModel,
  view: RerankerNode
}
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
<template>
|
||||
<NodeContainer :nodeModel="nodeModel">
|
||||
<el-card shadow="never" class="card-never" style="--el-card-padding: 12px">
|
||||
<el-form
|
||||
@submit.prevent
|
||||
:model="form_data"
|
||||
label-position="top"
|
||||
require-asterisk-position="right"
|
||||
label-width="auto"
|
||||
ref="rerankerNodeFormRef"
|
||||
>
|
||||
<el-form-item label="知识库检索结果">
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<span>知识库检索结果</span>
|
||||
<el-button @click="add_reranker_reference" link type="primary">
|
||||
<el-icon class="mr-4"><Plus /></el-icon>
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
<!-- One row per upstream retrieval-result reference.
     Fix: v-for previously had no :key, which triggers Vue's missing-key
     warning and can mispair cascader state after a row is deleted.
     The list is positional, so the index is an adequate key here. -->
<el-row
  :gutter="8"
  style="margin-bottom: 8px"
  v-for="(reranker_reference, index) in form_data.reranker_reference_list"
  :key="index"
>
|
||||
<el-col :span="22">
|
||||
<el-form-item
|
||||
:prop="'reranker_reference_list.' + index"
|
||||
:rules="{
|
||||
type: 'array',
|
||||
required: true,
|
||||
message: '请选择变量',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<NodeCascader
|
||||
:key="index"
|
||||
ref="nodeCascaderRef"
|
||||
:nodeModel="nodeModel"
|
||||
class="w-full"
|
||||
placeholder="请选择检索问题输入"
|
||||
v-model="form_data.reranker_reference_list[index]"
|
||||
/>
|
||||
</el-form-item>
|
||||
</el-col>
|
||||
<el-col :span="1">
|
||||
<el-button link type="info" class="mt-4" @click="deleteCondition(index)">
|
||||
<el-icon><Delete /></el-icon>
|
||||
</el-button>
|
||||
</el-col>
|
||||
</el-row>
|
||||
</el-form-item>
|
||||
<el-form-item label="检索参数">
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<span>检索参数</span>
|
||||
<el-button type="primary" link @click="openParamSettingDialog">
|
||||
<el-icon><Setting /></el-icon>
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
<div class="w-full">
|
||||
<el-row>
|
||||
<el-col :span="12" class="color-secondary lighter"> 相似度高于</el-col>
|
||||
<el-col :span="12" class="lighter">
|
||||
{{ form_data.reranker_setting.similarity }}</el-col
|
||||
>
|
||||
<el-col :span="12" class="color-secondary lighter"> 引用分段 Top</el-col>
|
||||
<el-col :span="12" class="lighter"> {{ form_data.reranker_setting.top_n }}</el-col>
|
||||
<el-col :span="12" class="color-secondary lighter"> 最大引用字符数</el-col>
|
||||
<el-col :span="12" class="lighter">
|
||||
{{ form_data.reranker_setting.max_paragraph_char_number }}</el-col
|
||||
>
|
||||
</el-row>
|
||||
</div>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
label="检索问题"
|
||||
prop="question_reference_address"
|
||||
:rules="{
|
||||
message: '请选择检索问题输入',
|
||||
trigger: 'blur',
|
||||
required: true
|
||||
}"
|
||||
>
|
||||
<NodeCascader
|
||||
ref="nodeCascaderRef"
|
||||
:nodeModel="nodeModel"
|
||||
class="w-full"
|
||||
placeholder="检索问题"
|
||||
v-model="form_data.question_reference_address"
|
||||
/>
|
||||
</el-form-item>
|
||||
<el-form-item
|
||||
label="AI 模型"
|
||||
prop="reranker_model_id"
|
||||
:rules="{
|
||||
required: true,
|
||||
message: '请选择 AI 模型',
|
||||
trigger: 'change'
|
||||
}"
|
||||
>
|
||||
<el-select
|
||||
@wheel="wheel"
|
||||
:teleported="false"
|
||||
v-model="form_data.reranker_model_id"
|
||||
placeholder="请选择 AI 模型"
|
||||
class="w-full"
|
||||
popper-class="select-model"
|
||||
:clearable="true"
|
||||
>
|
||||
<el-option-group
|
||||
v-for="(value, label) in modelOptions"
|
||||
:key="value"
|
||||
:label="relatedObject(providerOptions, label, 'provider')?.name"
|
||||
>
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
>
|
||||
<div class="flex align-center">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
|
||||
>公用
|
||||
</el-tag>
|
||||
</div>
|
||||
<!-- Fix: this node stores the selection in `reranker_model_id` (see the
     select's v-model); comparing against the nonexistent `model_id` meant
     the check mark never rendered for the chosen model. -->
<el-icon class="check-icon" v-if="item.id === form_data.reranker_model_id"
  ><Check
/></el-icon>
|
||||
</el-option>
|
||||
<el-option
|
||||
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
|
||||
:key="item.id"
|
||||
:label="item.name"
|
||||
:value="item.id"
|
||||
class="flex-between"
|
||||
disabled
|
||||
>
|
||||
<div class="flex">
|
||||
<span
|
||||
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
|
||||
class="model-icon mr-8"
|
||||
></span>
|
||||
<span>{{ item.name }}</span>
|
||||
<span class="danger">(不可用)</span>
|
||||
</div>
|
||||
<!-- Fix: compare against `reranker_model_id` (the field bound by v-model);
     `form_data.model_id` does not exist on this node, so the check mark
     never showed on the selected-but-unavailable model either. -->
<el-icon class="check-icon" v-if="item.id === form_data.reranker_model_id"
  ><Check
/></el-icon>
|
||||
</el-option>
|
||||
</el-option-group>
|
||||
<template #footer>
|
||||
<div class="w-full text-left cursor" @click="openCreateModel()">
|
||||
<el-button type="primary" link>
|
||||
<el-icon class="mr-4"><Plus /></el-icon>
|
||||
添加模型
|
||||
</el-button>
|
||||
</div>
|
||||
</template>
|
||||
</el-select>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
</el-card>
|
||||
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
|
||||
<!-- 添加模版 -->
|
||||
<CreateModelDialog
|
||||
ref="createModelRef"
|
||||
@submit="getModel"
|
||||
@change="openCreateModel($event)"
|
||||
></CreateModelDialog>
|
||||
</NodeContainer>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { set, cloneDeep, groupBy } from 'lodash'
|
||||
import NodeContainer from '@/workflow/common/NodeContainer.vue'
|
||||
import NodeCascader from '@/workflow/common/NodeCascader.vue'
|
||||
import ParamSettingDialog from './ParamSettingDialog.vue'
|
||||
import { ref, computed, onMounted } from 'vue'
|
||||
import { isLastNode } from '@/workflow/common/data'
|
||||
import type { Provider } from '@/api/type/model'
|
||||
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
|
||||
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
|
||||
import applicationApi from '@/api/application'
|
||||
import useStore from '@/stores'
|
||||
import { app } from '@/main'
|
||||
import { relatedObject } from '@/utils/utils'
|
||||
|
||||
const { model } = useStore()
|
||||
const props = defineProps<{ nodeModel: any }>()
|
||||
const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
|
||||
const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
|
||||
const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
|
||||
const {
|
||||
params: { id }
|
||||
} = app.config.globalProperties.$route as any
|
||||
// Default node payload, used the first time this node is placed on the
// canvas (seeded into node_data by the form_data computed getter).
const form = {
  // Upstream dataset-search results to re-rank: node-variable paths
  // selected through the NodeCascader rows.
  reranker_reference_list: [],
  // Selected RERANKER model id; empty until chosen in the model picker.
  reranker_model_id: '',
  // Retrieval parameters, edited through ParamSettingDialog.
  reranker_setting: {
    top_n: 3,
    similarity: 0.6,
    max_paragraph_char_number: 5000
  }
}
|
||||
const providerOptions = ref<Array<Provider>>([])
|
||||
const modelOptions = ref<any>(null)
|
||||
const openParamSettingDialog = () => {
|
||||
ParamSettingDialogRef.value?.open(form_data.value.dataset_setting, 'WORK_FLOW')
|
||||
}
|
||||
const deleteCondition = (index: number) => {
|
||||
const list = cloneDeep(props.nodeModel.properties.node_data.reranker_reference_list)
|
||||
list.splice(index, 1)
|
||||
set(props.nodeModel.properties.node_data, 'reranker_reference_list', list)
|
||||
}
|
||||
const wheel = (e: any) => {
|
||||
if (e.ctrlKey === true) {
|
||||
e.preventDefault()
|
||||
return true
|
||||
} else {
|
||||
e.stopPropagation()
|
||||
return true
|
||||
}
|
||||
}
|
||||
const form_data = computed({
|
||||
get: () => {
|
||||
if (props.nodeModel.properties.node_data) {
|
||||
return props.nodeModel.properties.node_data
|
||||
} else {
|
||||
set(props.nodeModel.properties, 'node_data', form)
|
||||
}
|
||||
return props.nodeModel.properties.node_data
|
||||
},
|
||||
set: (value) => {
|
||||
set(props.nodeModel.properties, 'node_data', value)
|
||||
}
|
||||
})
|
||||
function refreshParam(data: any) {
|
||||
set(props.nodeModel.properties.node_data, 'reranker_setting', data)
|
||||
}
|
||||
function getModel() {
|
||||
if (id) {
|
||||
applicationApi.getApplicationRerankerModel(id).then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
} else {
|
||||
model.asyncGetModel({ model_type: 'RERANKER' }).then((res: any) => {
|
||||
modelOptions.value = groupBy(res?.data, 'provider')
|
||||
})
|
||||
}
|
||||
}
|
||||
function getProvider() {
|
||||
model.asyncGetProvider().then((res: any) => {
|
||||
providerOptions.value = res?.data
|
||||
})
|
||||
}
|
||||
const add_reranker_reference = () => {
|
||||
const list = cloneDeep(props.nodeModel.properties.node_data.reranker_reference_list)
|
||||
list.push([])
|
||||
set(props.nodeModel.properties.node_data, 'reranker_reference_list', list)
|
||||
}
|
||||
const rerankerNodeFormRef = ref()
|
||||
const nodeCascaderRef = ref()
|
||||
const validate = () => {
|
||||
return Promise.all([
|
||||
nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''),
|
||||
rerankerNodeFormRef.value?.validate()
|
||||
]).catch((err: any) => {
|
||||
return Promise.reject({ node: props.nodeModel, errMessage: err })
|
||||
})
|
||||
}
|
||||
const openCreateModel = (provider?: Provider) => {
|
||||
if (provider && provider.provider) {
|
||||
createModelRef.value?.open(provider)
|
||||
} else {
|
||||
selectProviderRef.value?.open()
|
||||
}
|
||||
}
|
||||
onMounted(() => {
|
||||
getProvider()
|
||||
getModel()
|
||||
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
|
||||
if (isLastNode(props.nodeModel)) {
|
||||
set(props.nodeModel.properties.node_data, 'is_result', true)
|
||||
}
|
||||
}
|
||||
|
||||
set(props.nodeModel, 'validate', validate)
|
||||
})
|
||||
</script>
|
||||
<style lang="scss" scoped>
|
||||
.reply-node-editor {
|
||||
:deep(.md-editor-footer) {
|
||||
border: none !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
Loading…
Reference in New Issue