feat: 支持重排模型

This commit is contained in:
shaohuzhang1 2024-09-05 11:28:21 +08:00 committed by shaohuzhang1
parent fcbfd8a07c
commit 3faba75d4d
20 changed files with 983 additions and 6 deletions

View File

@ -14,9 +14,10 @@ from .start_node import *
from .direct_reply_node import *
from .function_lib_node import *
from .function_node import *
from .reranker_node import *
node_list = [BaseStartStepNode, BaseChatNode, BaseSearchDatasetNode, BaseQuestionNode, BaseConditionNode, BaseReplyNode,
BaseFunctionNodeNode, BaseFunctionLibNodeNode]
BaseFunctionNodeNode, BaseFunctionLibNodeNode, BaseRerankerNode]
def get_node(node_type):

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/4 11:37
@desc:
"""
from .impl import *

View File

@ -0,0 +1,59 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_reranker_node.py
@date2024/9/4 10:40
@desc:
"""
from typing import Type
from rest_framework import serializers
from application.flow.i_step_node import INode, NodeResult
from common.util.field_message import ErrMessage
class RerankerSettingSerializer(serializers.Serializer):
    """Validated settings controlling how reranked paragraphs are filtered."""
    # Number of top-ranked paragraphs to keep.
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Relevance-score threshold, 0-2 (blend search mode can exceed 1).
    # Fix: error message previously copy-pasted "引用分段数" from top_n.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # Maximum total number of characters quoted across kept paragraphs.
    # Fix: use ErrMessage.integer to match the IntegerField (was ErrMessage.float).
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         error_messages=ErrMessage.integer("最大引用分段字数"))
class RerankerStepNodeSerializer(serializers.Serializer):
    """Parameter payload of the reranker workflow node."""
    # Nested filter settings (top_n / similarity / max_paragraph_char_number).
    reranker_setting = RerankerSettingSerializer(required=True)
    # Reference path ([node_id, *field_path]) pointing at the question text.
    question_reference_address = serializers.ListField(required=True)
    # Primary key of the reranker model to use.
    reranker_model_id = serializers.UUIDField(required=True)
    # One reference path per retrieval-result list to merge and rerank.
    reranker_reference_list = serializers.ListField(required=True, child=serializers.ListField(required=True))

    def is_valid(self, *, raise_exception=False):
        # NOTE(review): always validates with raise_exception=True, ignoring the
        # caller-supplied flag — confirm this matches the other node serializers.
        super().is_valid(raise_exception=True)
class IRerankerNode(INode):
    """Abstract workflow node that re-ranks multi-source retrieval results."""
    # Node type identifier used by the workflow node registry.
    type = 'reranker-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """Return the serializer used to validate this node's parameters."""
        return RerankerStepNodeSerializer

    def _run(self):
        # Resolve the question text from its reference path ([node_id, *field_path]).
        question = self.workflow_manage.get_reference_field(
            self.node_params_serializer.data.get('question_reference_address')[0],
            self.node_params_serializer.data.get('question_reference_address')[1:])
        # Resolve every configured retrieval-result reference into its value.
        reranker_list = [self.workflow_manage.get_reference_field(
            reference[0],
            reference[1:]) for reference in
            self.node_params_serializer.data.get('reranker_reference_list')]
        return self.execute(**self.node_params_serializer.data, question=str(question),
                            reranker_list=reranker_list)

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        """Rerank *reranker_list* against *question*; implemented by subclasses."""
        pass

View File

@ -0,0 +1,9 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/9/4 11:39
@desc:
"""
from .base_reranker_node import *

View File

@ -0,0 +1,73 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_reranker_node.py
@date2024/9/4 11:41
@desc:
"""
from typing import List
from langchain_core.documents import Document
from application.flow.i_step_node import NodeResult
from application.flow.step_node.reranker_node.i_reranker_node import IRerankerNode
from setting.models_provider.tools import get_model_instance_by_model_user_id
def merge_reranker_list(reranker_list, result=None):
    """Recursively flatten nested retrieval results into a flat list of strings.

    Dicts contribute their concatenated 'title' + 'content' text (falling back
    to ``str(dict)`` when both are empty); every other non-list value is
    stringified. *result* is the shared accumulator used by the recursion.
    """
    result = [] if result is None else result
    for item in reranker_list:
        if isinstance(item, list):
            # Nested result list: descend and keep appending into the same accumulator.
            merge_reranker_list(item, result)
        elif isinstance(item, dict):
            text = item.get('title', '') + item.get('content', '')
            result.append(text if text else str(item))
        else:
            result.append(str(item))
    return result
def filter_result(document_list: List[Document], max_paragraph_char_number, top_n, similarity):
    """Trim a score-sorted document list by budget, count and score threshold.

    Iterates *document_list* in order and stops at the first document that
    would exceed the character budget, the top_n limit, or falls below the
    similarity threshold. The last kept document is truncated to the remaining
    character budget. Returns dicts of {'page_content', 'metadata'}.
    """
    kept = []
    used = 0
    for rank, doc in enumerate(document_list):
        # The list is sorted by relevance, so the first failing document ends the scan.
        if used >= max_paragraph_char_number or rank >= top_n or doc.metadata.get(
                'relevance_score') < similarity:
            break
        text = doc.page_content[:max_paragraph_char_number - used]
        used += len(text)
        kept.append({'page_content': text, 'metadata': doc.metadata})
    return kept
class BaseRerankerNode(IRerankerNode):
    """Concrete reranker node: merges results, scores them and filters the output."""

    def execute(self, question, reranker_setting, reranker_list, reranker_model_id,
                **kwargs) -> NodeResult:
        # Flatten the (possibly nested) retrieval results into plain strings.
        documents = merge_reranker_list(reranker_list)
        reranker_model = get_model_instance_by_model_user_id(reranker_model_id,
                                                             self.flow_params_serializer.data.get('user_id'))
        # Score every non-empty document against the question.
        result = reranker_model.compress_documents(
            [Document(page_content=document) for document in documents if document is not None and len(document) > 0],
            question)
        top_n = reranker_setting.get('top_n', 3)
        similarity = reranker_setting.get('similarity', 0.6)
        max_paragraph_char_number = reranker_setting.get('max_paragraph_char_number', 5000)
        # Keep only documents within the score/count/character limits.
        r = filter_result(result, max_paragraph_char_number, top_n, similarity)
        return NodeResult({'result_list': r, 'result': ''.join([item.get('page_content') for item in r])}, {})

    def get_details(self, index: int, **kwargs):
        """Return the execution details shown for this node in the run log."""
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            # NOTE(review): the params serializer declares no 'question' field, so this
            # appears to always be None — confirm whether it should come from context.
            "question": self.node_params_serializer.data.get('question'),
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'reranker_setting': self.node_params_serializer.data.get('reranker_setting'),
            'result_list': self.context.get('result_list'),
            'result': self.context.get('result'),
            'status': self.status,
            'err_message': self.err_message
        }

View File

@ -141,6 +141,7 @@ class ModelTypeConst(Enum):
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}
class ModelInfo:

View File

@ -0,0 +1,47 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py
@date2024/9/3 14:33
@desc:
"""
from typing import Dict
from langchain_core.documents import Document
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker
class LocalRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for locally hosted reranker models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by loading the model and running a probe rerank.

        Returns False (or raises AppApiException when *raise_exception* is True)
        if required fields are missing or the probe call fails.
        """
        # NOTE(review): an unsupported model_type always raises, regardless of raise_exception.
        if not model_type == 'RERANKER':
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        # Required credential fields.
        for key in ['cache_dir']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            # Probe: instantiate the model and rerank a single trivial document.
            model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential)
            model.compress_documents([Document(page_content='你好')], '你好')
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # cache_dir is a filesystem path, not a secret, so nothing is masked.
        return model

    # Filesystem directory where model weights are cached/downloaded.
    cache_dir = forms.TextInputField('模型目录', required=True)

View File

@ -16,14 +16,20 @@ from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential
from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential
from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker
from smartdoc.conf import PROJECT_DIR
embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING,
LocalEmbeddingCredential(), LocalEmbedding)
bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER,
LocalRerankerCredential(), LocalReranker)
model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese)
.append_default_model_info(embedding_text2vec_base_chinese)
.append_model_info(bge_reranker_v2_m3)
.append_default_model_info(bge_reranker_v2_m3)
.build())

View File

@ -0,0 +1,97 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py.py
@date2024/9/2 16:42
@desc:
"""
from typing import Sequence, Optional, Dict, Any
import requests
import torch
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from setting.models_provider.base_model_provider import MaxKBBaseModel
from smartdoc.const import CONFIG
class LocalReranker(MaxKBBaseModel):
    """Provider entry point for local rerankers.

    Acts as a factory: ``new_instance`` returns an in-process reranker
    (LocalBaseReranker) or an HTTP proxy to the local-model service
    (WebLocalBaseReranker) depending on the 'use_local' kwarg.
    """

    def __init__(self, model_name, top_n=3, cache_dir=None):
        super().__init__()
        self.model_name = model_name
        self.cache_dir = cache_dir
        self.top_n = top_n

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # In-process model by default; the web proxy is used when use_local is falsy
        # (presumably the main web process delegating to the local_model service — confirm).
        if model_kwargs.get('use_local', True):
            return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
                                     model_kwargs={'device': model_credential.get('device', 'cpu')}
                                     )
        return WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'),
                                    model_kwargs={'device': model_credential.get('device')},
                                    **model_kwargs)
class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker proxy that forwards compress requests to the local-model HTTP service."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        # Instances are created directly by LocalReranker.new_instance, not via this hook.
        pass

    # Id of the remote model record; interpolated into the request URL.
    model_id: str = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model_id = kwargs.get('model_id', None)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """POST documents and query to the local-model service and rebuild Documents.

        Raises:
            Exception: when the service responds with a non-200 business code.
        """
        # Service address comes from LOCAL_MODEL_* settings.
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        res = requests.post(
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
            json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in
                                documents], 'query': query}, headers={'Content-Type': 'application/json'})
        result = res.json()
        if result.get('code', 500) == 200:
            return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document
                    in result.get('data')]
        raise Exception(result.get('msg'))
class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Reranker that runs a HuggingFace sequence-classification model in-process."""
    # Loaded AutoModelForSequenceClassification instance.
    client: Any = None
    # Matching AutoTokenizer instance.
    tokenizer: Any = None
    # HuggingFace model name or path.
    model: Optional[str] = None
    # Local cache directory for downloaded weights.
    cache_dir: Optional[str] = None
    # NOTE(review): mutable class-level default shared across instances;
    # __init__ rebinds it per instance, but a factory default would be safer.
    model_kwargs = {}

    def __init__(self, model_name, cache_dir=None, **model_kwargs):
        super().__init__()
        self.model = model_name
        self.cache_dir = cache_dir
        self.model_kwargs = model_kwargs
        # Load model + tokenizer, move to the requested device (CPU by default)
        # and switch to inference mode.
        self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir)
        self.client = self.client.to(self.model_kwargs.get('device', 'cpu'))
        self.client.eval()

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credential data (cache_dir)."""
        return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs)

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Score each document against *query* and return them sorted by relevance.

        Scores are sigmoid-squashed classifier logits stored in
        metadata['relevance_score']; each (query, document) pair is truncated
        to 512 tokens before scoring.
        """
        with torch.no_grad():
            inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
                                    truncation=True, return_tensors='pt', max_length=512)
            scores = [torch.sigmoid(s).float().item() for s in
                      self.client(**inputs, return_dict=True).logits.view(-1, ).float()]
            result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]})
                      for index
                      in range(len(documents))]
            # Highest relevance first.
            result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True)
            return result

View File

@ -7,6 +7,7 @@
@desc:
"""
from django.db.models import QuerySet
from langchain_core.documents import Document
from rest_framework import serializers
from common.config.embedding_config import ModelManage
@ -33,6 +34,16 @@ class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本"))
class CompressDocument(serializers.Serializer):
    """A single document submitted for reranking."""
    page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本"))
    metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据"))


class CompressDocuments(serializers.Serializer):
    """Request body of the compress_documents (rerank) endpoint."""
    documents = CompressDocument(required=True, many=True)
    query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query"))
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
@ -51,3 +62,12 @@ class ModelApplySerializers(serializers.Serializer):
model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))
def compress_documents(self, instance, with_valid=True):
    """Rerank ``instance['documents']`` against ``instance['query']``.

    Returns a JSON-serializable list of {'page_content', 'metadata'} dicts
    ordered by the model's relevance score.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        CompressDocuments(data=instance).is_valid(raise_exception=True)
    # NOTE(review): get_embedding_model is reused to load a RERANKER model —
    # confirm the loader/cache is model-type agnostic.
    model = get_embedding_model(self.data.get('model_id'))
    return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
        [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
         instance.get('documents')], instance.get('query'))]

View File

@ -17,7 +17,8 @@ urlpatterns = [
path('provider/model_form', views.Provide.ModelForm.as_view(),
name="provider/model_form"),
path('model', views.Model.as_view(), name='model'),
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(), name='model/model_params_form'),
path('model/<str:model_id>/model_params_form', views.Model.ModelParamsForm.as_view(),
name='model/model_params_form'),
path('model/<str:model_id>', views.Model.Operate.as_view(), name='model/operate'),
path('model/<str:model_id>/pause_download', views.Model.PauseDownload.as_view(), name='model/operate'),
path('model/<str:model_id>/meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'),
@ -31,4 +32,6 @@ if os.environ.get('SERVER_NAME', 'web') == 'local_model':
name='model/embed_documents'),
path('model/<str:model_id>/embed_query', views.ModelApply.EmbedQuery.as_view(),
name='model/embed_query'),
# Rerank endpoint, exposed only by the local_model server.
# Fix: route name was 'model/embed_query', duplicating the embed_query route
# and making reverse() lookups ambiguous.
path('model/<str:model_id>/compress_documents', views.ModelApply.CompressDocuments.as_view(),
     name='model/compress_documents'),
]

View File

@ -36,3 +36,13 @@ class ModelApply(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data))
class CompressDocuments(APIView):
    """API endpoint that reranks posted documents with the given local model."""

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="重排序文档",
                         operation_id="重排序文档",
                         responses=result.get_default_response(),
                         tags=["模型"])
    def post(self, request: Request, model_id):
        # Body: {'documents': [{'page_content', 'metadata'}], 'query': str}.
        return result.success(
            ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))

View File

@ -237,6 +237,19 @@ const getApplicationModel: (
return get(`${prefix}/${application_id}/model`, loading)
}
/**
 * Get the reranker models available to an application
 * @param application_id application id
 * @param loading        optional loading ref
 * @returns reranker model list
 */
const getApplicationRerankerModel: (
  application_id: string,
  loading?: Ref<boolean>
) => Promise<Result<Array<any>>> = (application_id, loading) => {
  // Same endpoint as getApplicationModel, filtered server-side to RERANKER models.
  return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading)
}
/**
*
* @param
@ -310,5 +323,6 @@ export default {
postWorkflowChatOpen,
listFunctionLib,
getFunctionLib,
getModelParamsForm
getModelParamsForm,
getApplicationRerankerModel
}

View File

@ -0,0 +1,16 @@
<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" viewBox="0 0 24 24" fill="none">
<g clip-path="url(#clip0_6564_362280)">
<path d="M0 4C0 1.79086 1.95367 0 4.36364 0H19.6364C22.0463 0 24 1.79086 24 4V20C24 22.2091 22.0463 24 19.6364 24H4.36364C1.95367 24 0 22.2091 0 20V4Z" fill="none"/>
<path d="M8.00004 7.66667C8.00004 6.9303 7.40307 6.33333 6.66671 6.33333C5.93033 6.33333 5.33337 6.9303 5.33337 7.66667C5.33337 8.40303 5.93033 9 6.66671 9C7.40307 9 8.00004 8.40303 8.00004 7.66667Z" fill="white"/>
<path d="M8.00004 16.3333C8.00004 15.5969 7.40307 15 6.66671 15C5.93033 15 5.33337 15.5969 5.33337 16.3333C5.33337 17.0697 5.93033 17.6666 6.66671 17.6666C7.40307 17.6666 8.00004 17.0697 8.00004 16.3333Z" fill="white"/>
<path d="M19.3334 12C19.3334 10.8955 18.4379 10 17.3334 10C16.2288 10 15.3334 10.8955 15.3334 12C15.3334 13.1046 16.2288 14 17.3334 14C18.4379 14 19.3334 13.1046 19.3334 12Z" fill="white"/>
<path d="M8.66663 7.33337C8.66663 6.22882 7.77118 5.33337 6.66663 5.33337C5.56208 5.33337 4.66663 6.22882 4.66663 7.33337C4.66663 8.43792 5.56208 9.33337 6.66663 9.33337C7.77118 9.33337 8.66663 8.43792 8.66663 7.33337Z" fill="white"/>
<path d="M8.66663 16.6666C8.66663 15.5621 7.77118 14.6666 6.66663 14.6666C5.56208 14.6666 4.66663 15.5621 4.66663 16.6666C4.66663 17.7712 5.56208 18.6666 6.66663 18.6666C7.77118 18.6666 8.66663 17.7712 8.66663 16.6666Z" fill="white"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.33337 16.3333C7.33337 15.9651 7.63185 15.6667 8.00004 15.6667H9.00004C9.35848 15.6667 9.70703 15.5052 10.0995 15.1574C10.5012 14.8013 10.8785 14.3145 11.305 13.7602C11.3119 13.7511 11.3189 13.7421 11.3259 13.733C11.7275 13.2108 12.1767 12.6268 12.6828 12.1782C12.7517 12.1171 12.8229 12.0575 12.8962 12C12.8229 11.9425 12.7517 11.8829 12.6828 11.8218C12.1767 11.3732 11.7275 10.7892 11.3259 10.267L11.305 10.2398C10.8785 9.68546 10.5012 9.19868 10.0995 8.84265C9.70703 8.49477 9.35848 8.33333 9.00004 8.33333H8.00004C7.63185 8.33333 7.33337 8.03486 7.33337 7.66667C7.33337 7.29848 7.63185 7 8.00004 7H9.00004C9.80827 7 10.4597 7.38023 10.9839 7.84485C11.4901 8.2935 11.9392 8.87749 12.3409 9.39964L12.3618 9.42686C12.7882 9.9812 13.1656 10.468 13.5672 10.824C13.9597 11.1719 14.3083 11.3333 14.6667 11.3333H16C16.3682 11.3333 16.6667 11.6318 16.6667 12C16.6667 12.3682 16.3682 12.6667 16 12.6667H14.6667C14.3083 12.6667 13.9597 12.8281 13.5672 13.176C13.1656 13.532 12.7882 14.0188 12.3618 14.5731L12.3409 14.6004C11.9392 15.1225 11.4901 15.7065 10.9839 16.1551C10.4597 16.6198 9.80827 17 9.00004 17H8.00004C7.63185 17 7.33337 16.7015 7.33337 16.3333Z" fill="white"/>
</g>
<defs>
<clipPath id="clip0_6564_362280">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 2.7 KiB

View File

@ -7,5 +7,6 @@ export enum WorkflowType {
Condition = 'condition-node',
Reply = 'reply-node',
FunctionLib = 'function-lib-node',
FunctionLibCustom = 'function-node'
FunctionLibCustom = 'function-node',
RrerankerNode = 'reranker-node'
}

View File

@ -139,7 +139,30 @@ export const replyNode = {
}
}
}
export const menuNodes = [aiChatNode, searchDatasetNode, questionNode, conditionNode, replyNode]
// Reranker workflow node: second-pass recall over results from multiple
// knowledge bases using a reranker model.
export const rerankerNode = {
  // NOTE(review): the enum member name 'RrerankerNode' is misspelled but its
  // value is correct; renaming would require touching every usage site.
  type: WorkflowType.RrerankerNode,
  text: '使用重排模型对多个知识库的检索结果进行二次召回',
  label: '多路召回',
  properties: {
    stepName: '多路召回',
    config: {
      fields: [
        {
          label: '结果',
          value: 'result'
        }
      ]
    }
  }
}
// Nodes shown in the workflow editor's add-node menu.
export const menuNodes = [
  aiChatNode,
  searchDatasetNode,
  questionNode,
  conditionNode,
  replyNode,
  rerankerNode
]
/**
*
@ -203,7 +226,8 @@ export const nodeDict: any = {
[WorkflowType.Start]: startNode,
[WorkflowType.Reply]: replyNode,
[WorkflowType.FunctionLib]: functionLibNode,
[WorkflowType.FunctionLibCustom]: functionNode
[WorkflowType.FunctionLibCustom]: functionNode,
[WorkflowType.RrerankerNode]: rerankerNode
}
export function isWorkFlow(type: string | undefined) {
return type === 'WORK_FLOW'

View File

@ -0,0 +1,6 @@
<template>
<AppAvatar shape="square" style="background: #d136d1">
<img src="@/assets/icon_reranker.svg" style="width: 75%" alt="" />
</AppAvatar>
</template>
<script setup lang="ts"></script>

View File

@ -0,0 +1,266 @@
<template>
<el-dialog
align-center
:title="$t('views.application.applicationForm.dialogues.paramSettings')"
class="param-dialog"
v-model="dialogVisible"
style="width: 550px"
append-to-body
>
<div>
<el-scrollbar always>
<div class="p-16">
<el-form label-position="top" ref="paramFormRef" :model="form">
<el-row :gutter="20">
<el-col :span="12">
<el-form-item>
<template #label>
<div class="flex align-center">
<span class="mr-4">{{
$t('views.application.applicationForm.dialogues.similarityThreshold')
}}</span>
<el-tooltip effect="dark" content="相似度越高相关性越强。" placement="right">
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-input-number
v-model="form.similarity"
:min="0"
:max="form.search_mode === 'blend' ? 2 : 1"
:precision="3"
:step="0.1"
:value-on-clear="0"
controls-position="right"
class="w-full"
/>
</el-form-item>
</el-col>
<el-col :span="12">
<el-form-item
:label="$t('views.application.applicationForm.dialogues.topReferences')"
>
<el-input-number
v-model="form.top_n"
:min="1"
:max="100"
:value-on-clear="1"
controls-position="right"
class="w-full"
/>
</el-form-item>
</el-col>
</el-row>
<el-form-item :label="$t('views.application.applicationForm.dialogues.maxCharacters')">
<el-slider
v-model="form.max_paragraph_char_number"
show-input
:show-input-controls="false"
:min="500"
:max="100000"
class="custom-slider"
/>
</el-form-item>
<el-form-item
v-if="!isWorkflowType"
:label="$t('views.application.applicationForm.dialogues.noReferencesAction')"
>
<el-form
label-position="top"
ref="noReferencesformRef"
:model="noReferencesform"
:rules="noReferencesRules"
class="w-full"
:hide-required-asterisk="true"
>
<el-radio-group
v-model="form.no_references_setting.status"
class="radio-block mb-16"
>
<div>
<el-radio value="ai_questioning">
<p>
{{ $t('views.application.applicationForm.dialogues.continueQuestioning') }}
</p>
<el-form-item
v-if="form.no_references_setting.status === 'ai_questioning'"
:label="$t('views.application.applicationForm.form.prompt.label')"
prop="ai_questioning"
>
<el-input
v-model="noReferencesform.ai_questioning"
:rows="2"
type="textarea"
maxlength="2048"
:placeholder="defaultValue['ai_questioning']"
/>
</el-form-item>
</el-radio>
</div>
<div class="mt-8">
<el-radio value="designated_answer">
<p>{{ $t('views.application.applicationForm.dialogues.provideAnswer') }}</p>
<el-form-item
v-if="form.no_references_setting.status === 'designated_answer'"
prop="designated_answer"
>
<el-input
v-model="noReferencesform.designated_answer"
:rows="2"
type="textarea"
maxlength="2048"
:placeholder="defaultValue['designated_answer']"
/>
</el-form-item>
</el-radio>
</div>
</el-radio-group>
</el-form>
</el-form-item>
</el-form>
</div>
</el-scrollbar>
</div>
<template #footer>
<span class="dialog-footer p-16">
<el-button @click.prevent="dialogVisible = false">{{
$t('views.application.applicationForm.buttons.cancel')
}}</el-button>
<el-button type="primary" @click="submit(noReferencesformRef)" :loading="loading">
{{ $t('views.application.applicationForm.buttons.save') }}
</el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch, reactive } from 'vue'
import { cloneDeep } from 'lodash'
import type { FormInstance, FormRules } from 'element-plus'
import { isWorkFlow } from '@/utils/application'
import { t } from '@/locales'
const emit = defineEmits(['refresh'])
const paramFormRef = ref()
const noReferencesformRef = ref()
const defaultValue = {
ai_questioning: '{question}',
// @ts-ignore
designated_answer: t('views.application.applicationForm.dialogues.designated_answer')
}
const form = ref<any>({
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000,
no_references_setting: {
status: 'ai_questioning',
value: '{question}'
}
})
const noReferencesform = ref<any>({
ai_questioning: defaultValue['ai_questioning'],
designated_answer: defaultValue['designated_answer']
})
const noReferencesRules = reactive<FormRules<any>>({
ai_questioning: [
{
required: true,
message: t('views.application.applicationForm.dialogues.promptPlaceholder'),
trigger: 'blur'
}
],
designated_answer: [
{
required: true,
message: t('views.application.applicationForm.dialogues.concentPlaceholder'),
trigger: 'blur'
}
]
})
const dialogVisible = ref<boolean>(false)
const loading = ref(false)
const isWorkflowType = ref(false)
watch(dialogVisible, (bool) => {
if (!bool) {
form.value = {
search_mode: 'embedding',
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000,
no_references_setting: {
status: 'ai_questioning',
value: ''
}
}
noReferencesform.value = {
ai_questioning: defaultValue['ai_questioning'],
designated_answer: defaultValue['designated_answer']
}
noReferencesformRef.value?.clearValidate()
}
})
// Open the dialog pre-filled with *data*; *type* toggles workflow mode
// (workflow nodes hide the "no references" section).
const open = (data: any, type?: string) => {
  isWorkflowType.value = isWorkFlow(type)
  form.value = { ...form.value, ...cloneDeep(data) }
  noReferencesform.value[form.value.no_references_setting.status] =
    form.value.no_references_setting.value
  dialogVisible.value = true
}

// Validate (non-workflow only) and emit the edited settings back to the parent.
const submit = async (formEl: FormInstance | undefined) => {
  if (isWorkflowType.value) {
    // Workflow nodes carry no no_references_setting; strip it before emitting.
    delete form.value['no_references_setting']
    emit('refresh', form.value)
    dialogVisible.value = false
  } else {
    if (!formEl) return
    await formEl.validate((valid, fields) => {
      if (valid) {
        // Persist the answer text of whichever mode is selected.
        form.value.no_references_setting.value =
          noReferencesform.value[form.value.no_references_setting.status]
        emit('refresh', form.value)
        dialogVisible.value = false
      }
    })
  }
}

// Keyword search has no similarity score, so reset the threshold.
// NOTE(review): appears unused in this component — confirm before removing.
function changeHandle(val: string) {
  if (val === 'keywords') {
    form.value.similarity = 0
  } else {
    form.value.similarity = 0.6
  }
}
defineExpose({ open })
</script>
<style lang="scss" scope>
.param-dialog {
padding: 8px 8px 24px 8px;
.el-dialog__header {
padding: 16px 16px 0 16px;
}
.el-dialog__body {
padding: 0 !important;
}
.dialog-max-height {
height: 550px;
}
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {
padding: 0 !important;
}
}
}
</style>

View File

@ -0,0 +1,12 @@
import RerankerNodeVue from './index.vue'
import { AppNode, AppNodeModel } from '@/workflow/common/app-node'
// LogicFlow registration for the reranker workflow node.
class RerankerNode extends AppNode {
  constructor(props: any) {
    // Bind the shared AppNode behavior to this node's Vue component.
    super(props, RerankerNodeVue)
  }
}

export default {
  // Must match the backend node type and WorkflowType.RrerankerNode's value.
  type: 'reranker-node',
  model: AppNodeModel,
  view: RerankerNode
}

View File

@ -0,0 +1,303 @@
<template>
<NodeContainer :nodeModel="nodeModel">
<el-card shadow="never" class="card-never" style="--el-card-padding: 12px">
<el-form
@submit.prevent
:model="form_data"
label-position="top"
require-asterisk-position="right"
label-width="auto"
ref="rerankerNodeFormRef"
>
<el-form-item label="知识库检索结果">
<template #label>
<div class="flex-between">
<span>知识库检索结果</span>
<el-button @click="add_reranker_reference" link type="primary">
<el-icon class="mr-4"><Plus /></el-icon>
</el-button>
</div>
</template>
<el-row
:gutter="8"
style="margin-bottom: 8px"
v-for="(reranker_reference, index) in form_data.reranker_reference_list"
>
<el-col :span="22">
<el-form-item
:prop="'reranker_reference_list.' + index"
:rules="{
type: 'array',
required: true,
message: '请选择变量',
trigger: 'change'
}"
>
<NodeCascader
:key="index"
ref="nodeCascaderRef"
:nodeModel="nodeModel"
class="w-full"
placeholder="请选择检索问题输入"
v-model="form_data.reranker_reference_list[index]"
/>
</el-form-item>
</el-col>
<el-col :span="1">
<el-button link type="info" class="mt-4" @click="deleteCondition(index)">
<el-icon><Delete /></el-icon>
</el-button>
</el-col>
</el-row>
</el-form-item>
<el-form-item label="检索参数">
<template #label>
<div class="flex-between">
<span>检索参数</span>
<el-button type="primary" link @click="openParamSettingDialog">
<el-icon><Setting /></el-icon>
</el-button>
</div>
</template>
<div class="w-full">
<el-row>
<el-col :span="12" class="color-secondary lighter"> 相似度高于</el-col>
<el-col :span="12" class="lighter">
{{ form_data.reranker_setting.similarity }}</el-col
>
<el-col :span="12" class="color-secondary lighter"> 引用分段 Top</el-col>
<el-col :span="12" class="lighter"> {{ form_data.reranker_setting.top_n }}</el-col>
<el-col :span="12" class="color-secondary lighter"> 最大引用字符数</el-col>
<el-col :span="12" class="lighter">
{{ form_data.reranker_setting.max_paragraph_char_number }}</el-col
>
</el-row>
</div>
</el-form-item>
<el-form-item
label="检索问题"
prop="question_reference_address"
:rules="{
message: '请选择检索问题输入',
trigger: 'blur',
required: true
}"
>
<NodeCascader
ref="nodeCascaderRef"
:nodeModel="nodeModel"
class="w-full"
placeholder="检索问题"
v-model="form_data.question_reference_address"
/>
</el-form-item>
<el-form-item
label="AI 模型"
prop="reranker_model_id"
:rules="{
required: true,
message: '请选择 AI 模型',
trigger: 'change'
}"
>
<el-select
@wheel="wheel"
:teleported="false"
v-model="form_data.reranker_model_id"
placeholder="请选择 AI 模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
:key="value"
:label="relatedObject(providerOptions, label, 'provider')?.name"
>
<el-option
v-for="item in value.filter((v: any) => v.status === 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
>
<div class="flex align-center">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<el-tag v-if="item.permission_type === 'PUBLIC'" type="info" class="info-tag ml-8"
>公用
</el-tag>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id"
><Check
/></el-icon>
</el-option>
<el-option
v-for="item in value.filter((v: any) => v.status !== 'SUCCESS')"
:key="item.id"
:label="item.name"
:value="item.id"
class="flex-between"
disabled
>
<div class="flex">
<span
v-html="relatedObject(providerOptions, label, 'provider')?.icon"
class="model-icon mr-8"
></span>
<span>{{ item.name }}</span>
<span class="danger">不可用</span>
</div>
<el-icon class="check-icon" v-if="item.id === form_data.model_id"
><Check
/></el-icon>
</el-option>
</el-option-group>
<template #footer>
<div class="w-full text-left cursor" @click="openCreateModel()">
<el-button type="primary" link>
<el-icon class="mr-4"><Plus /></el-icon>
添加模型
</el-button>
</div>
</template>
</el-select>
</el-form-item>
</el-form>
</el-card>
<ParamSettingDialog ref="ParamSettingDialogRef" @refresh="refreshParam" />
<!-- 添加模版 -->
<CreateModelDialog
ref="createModelRef"
@submit="getModel"
@change="openCreateModel($event)"
></CreateModelDialog>
</NodeContainer>
</template>
<script setup lang="ts">
import { set, cloneDeep, groupBy } from 'lodash'
import NodeContainer from '@/workflow/common/NodeContainer.vue'
import NodeCascader from '@/workflow/common/NodeCascader.vue'
import ParamSettingDialog from './ParamSettingDialog.vue'
import { ref, computed, onMounted } from 'vue'
import { isLastNode } from '@/workflow/common/data'
import type { Provider } from '@/api/type/model'
import CreateModelDialog from '@/views/template/component/CreateModelDialog.vue'
import SelectProviderDialog from '@/views/template/component/SelectProviderDialog.vue'
import applicationApi from '@/api/application'
import useStore from '@/stores'
import { app } from '@/main'
import { relatedObject } from '@/utils/utils'
const { model } = useStore()
const props = defineProps<{ nodeModel: any }>()
const createModelRef = ref<InstanceType<typeof CreateModelDialog>>()
const selectProviderRef = ref<InstanceType<typeof SelectProviderDialog>>()
const ParamSettingDialogRef = ref<InstanceType<typeof ParamSettingDialog>>()
const {
params: { id }
} = app.config.globalProperties.$route as any
const form = {
reranker_reference_list: [],
reranker_model_id: '',
reranker_setting: {
top_n: 3,
similarity: 0.6,
max_paragraph_char_number: 5000
}
}
const providerOptions = ref<Array<Provider>>([])
const modelOptions = ref<any>(null)
// Open the parameter dialog seeded with this node's current reranker settings.
// Fix: previously passed form_data.value.dataset_setting, which does not exist
// on the reranker node (always undefined), so the dialog opened with defaults
// instead of the saved reranker_setting.
const openParamSettingDialog = () => {
  ParamSettingDialogRef.value?.open(form_data.value.reranker_setting, 'WORK_FLOW')
}
const deleteCondition = (index: number) => {
const list = cloneDeep(props.nodeModel.properties.node_data.reranker_reference_list)
list.splice(index, 1)
set(props.nodeModel.properties.node_data, 'reranker_reference_list', list)
}
const wheel = (e: any) => {
if (e.ctrlKey === true) {
e.preventDefault()
return true
} else {
e.stopPropagation()
return true
}
}
const form_data = computed({
get: () => {
if (props.nodeModel.properties.node_data) {
return props.nodeModel.properties.node_data
} else {
set(props.nodeModel.properties, 'node_data', form)
}
return props.nodeModel.properties.node_data
},
set: (value) => {
set(props.nodeModel.properties, 'node_data', value)
}
})
function refreshParam(data: any) {
set(props.nodeModel.properties.node_data, 'reranker_setting', data)
}
function getModel() {
if (id) {
applicationApi.getApplicationRerankerModel(id).then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
} else {
model.asyncGetModel({ model_type: 'RERANKER' }).then((res: any) => {
modelOptions.value = groupBy(res?.data, 'provider')
})
}
}
function getProvider() {
model.asyncGetProvider().then((res: any) => {
providerOptions.value = res?.data
})
}
const add_reranker_reference = () => {
const list = cloneDeep(props.nodeModel.properties.node_data.reranker_reference_list)
list.push([])
set(props.nodeModel.properties.node_data, 'reranker_reference_list', list)
}
const rerankerNodeFormRef = ref()
const nodeCascaderRef = ref()
const validate = () => {
return Promise.all([
nodeCascaderRef.value ? nodeCascaderRef.value.validate() : Promise.resolve(''),
rerankerNodeFormRef.value?.validate()
]).catch((err: any) => {
return Promise.reject({ node: props.nodeModel, errMessage: err })
})
}
const openCreateModel = (provider?: Provider) => {
if (provider && provider.provider) {
createModelRef.value?.open(provider)
} else {
selectProviderRef.value?.open()
}
}
onMounted(() => {
getProvider()
getModel()
if (typeof props.nodeModel.properties.node_data?.is_result === 'undefined') {
if (isLastNode(props.nodeModel)) {
set(props.nodeModel.properties.node_data, 'is_result', true)
}
}
set(props.nodeModel, 'validate', validate)
})
</script>
<style lang="scss" scoped>
.reply-node-editor {
:deep(.md-editor-footer) {
border: none !important;
}
}
</style>