mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
Merge branch 'main' of https://github.com/1Panel-dev/MaxKB
This commit is contained in:
commit
93a959b3ef
|
|
@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
|
|||
message_list = serializers.ListField(required=True, child=MessageField(required=True),
|
||||
error_messages=ErrMessage.list("对话列表"))
|
||||
# 大语言模型
|
||||
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
|
||||
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
|
||||
# 段落列表
|
||||
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
|
||||
# 对话id
|
||||
|
|
|
|||
|
|
@ -59,8 +59,12 @@ def event_content(response,
|
|||
|
||||
# 获取token
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(all_text)
|
||||
try:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(all_text)
|
||||
except Exception as e:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
else:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
|
|
@ -126,6 +130,26 @@ class BaseChatStep(IChatStep):
|
|||
result.append({'role': 'ai', 'content': answer_text})
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_stream_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
|
||||
for paragraph in paragraph_list if
|
||||
paragraph.hit_handling_method == 'directly_return']
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
return iter(directly_return_chunk_list), False
|
||||
elif len(paragraph_list) == 0 and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
|
||||
if chat_model is None:
|
||||
return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False
|
||||
else:
|
||||
return chat_model.stream(message_list), True
|
||||
|
||||
def execute_stream(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
|
|
@ -136,29 +160,8 @@ class BaseChatStep(IChatStep):
|
|||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None,
|
||||
no_references_setting=None):
|
||||
is_ai_chat = False
|
||||
# 调用模型
|
||||
if chat_model is None:
|
||||
chat_result = iter(
|
||||
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
|
||||
else:
|
||||
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
|
||||
else:
|
||||
if paragraph_list is not None and len(paragraph_list) > 0:
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
|
||||
for paragraph in paragraph_list if
|
||||
paragraph.hit_handling_method == 'directly_return']
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
chat_result = iter(directly_return_chunk_list)
|
||||
else:
|
||||
chat_result = chat_model.stream(message_list)
|
||||
is_ai_chat = True
|
||||
else:
|
||||
chat_result = chat_model.stream(message_list)
|
||||
is_ai_chat = True
|
||||
|
||||
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting)
|
||||
chat_record_id = uuid.uuid1()
|
||||
r = StreamingHttpResponse(
|
||||
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
|
||||
|
|
@ -169,6 +172,27 @@ class BaseChatStep(IChatStep):
|
|||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def get_block_result(message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
no_references_setting=None):
|
||||
if paragraph_list is None:
|
||||
paragraph_list = []
|
||||
|
||||
directly_return_chunk_list = [AIMessage(content=paragraph.content)
|
||||
for paragraph in paragraph_list if
|
||||
paragraph.hit_handling_method == 'directly_return']
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
return directly_return_chunk_list[0], False
|
||||
elif len(paragraph_list) == 0 and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
return AIMessage(no_references_setting.get('value')), False
|
||||
if chat_model is None:
|
||||
return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False
|
||||
else:
|
||||
return chat_model.invoke(message_list), True
|
||||
|
||||
def execute_block(self, message_list: List[BaseMessage],
|
||||
chat_id,
|
||||
problem_text,
|
||||
|
|
@ -178,28 +202,8 @@ class BaseChatStep(IChatStep):
|
|||
manage: PipelineManage = None,
|
||||
padding_problem_text: str = None,
|
||||
client_id=None, client_type=None, no_references_setting=None):
|
||||
is_ai_chat = False
|
||||
# 调用模型
|
||||
if chat_model is None:
|
||||
chat_result = AIMessage(
|
||||
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
|
||||
else:
|
||||
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
|
||||
'status') == 'designated_answer':
|
||||
chat_result = AIMessage(content=no_references_setting.get('value'))
|
||||
else:
|
||||
if paragraph_list is not None and len(paragraph_list) > 0:
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
|
||||
for paragraph in paragraph_list if
|
||||
paragraph.hit_handling_method == 'directly_return']
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
chat_result = iter(directly_return_chunk_list)
|
||||
else:
|
||||
chat_result = chat_model.invoke(message_list)
|
||||
is_ai_chat = True
|
||||
else:
|
||||
chat_result = chat_model.invoke(message_list)
|
||||
is_ai_chat = True
|
||||
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
|
||||
chat_record_id = uuid.uuid1()
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
|
|||
padding_problem_text: str = None,
|
||||
no_references_setting=None,
|
||||
**kwargs) -> List[BaseMessage]:
|
||||
prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get(
|
||||
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
|
||||
'value')
|
||||
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
|
||||
start_index = len(history_chat_record) - dialogue_number
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
|
|||
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
|
||||
error_messages=ErrMessage.list("历史对答"))
|
||||
# 大语言模型
|
||||
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
|
||||
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
|
||||
|
||||
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
|
||||
return self.InstanceSerializer
|
||||
|
|
|
|||
|
|
@ -22,6 +22,10 @@ prompt = (
|
|||
class BaseResetProblemStep(IResetProblemStep):
|
||||
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
|
||||
**kwargs) -> str:
|
||||
if chat_model is None:
|
||||
self.context['message_tokens'] = 0
|
||||
self.context['answer_tokens'] = 0
|
||||
return problem_text
|
||||
start_index = len(history_chat_record) - 3
|
||||
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
|
||||
for index in
|
||||
|
|
@ -35,8 +39,14 @@ class BaseResetProblemStep(IResetProblemStep):
|
|||
response.content.index('<data>') + 6:response.content.index('</data>')]
|
||||
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
|
||||
padding_problem = padding_problem_data
|
||||
self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
|
||||
self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
|
||||
try:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(padding_problem)
|
||||
except Exception as e:
|
||||
request_token = 0
|
||||
response_token = 0
|
||||
self.context['message_tokens'] = request_token
|
||||
self.context['answer_tokens'] = response_token
|
||||
return padding_problem
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
# Generated by Django 4.1.13 on 2024-04-29 13:33
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0004_applicationaccesstoken_show_source'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='chat',
|
||||
name='abstract',
|
||||
field=models.CharField(max_length=1024, verbose_name='摘要'),
|
||||
),
|
||||
migrations.AlterField(
|
||||
model_name='chatrecord',
|
||||
name='answer_text',
|
||||
field=models.CharField(max_length=40960, verbose_name='答案'),
|
||||
),
|
||||
]
|
||||
|
|
@ -73,7 +73,7 @@ class ApplicationDatasetMapping(AppModelMixin):
|
|||
class Chat(AppModelMixin):
|
||||
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
|
||||
application = models.ForeignKey(Application, on_delete=models.CASCADE)
|
||||
abstract = models.CharField(max_length=256, verbose_name="摘要")
|
||||
abstract = models.CharField(max_length=1024, verbose_name="摘要")
|
||||
client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
|
||||
|
||||
class Meta:
|
||||
|
|
@ -96,7 +96,7 @@ class ChatRecord(AppModelMixin):
|
|||
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
|
||||
default=VoteChoices.UN_VOTE)
|
||||
problem_text = models.CharField(max_length=1024, verbose_name="问题")
|
||||
answer_text = models.CharField(max_length=4096, verbose_name="答案")
|
||||
answer_text = models.CharField(max_length=40960, verbose_name="答案")
|
||||
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
|
||||
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
|
||||
const = models.IntegerField(verbose_name="总费用", default=0)
|
||||
|
|
|
|||
|
|
@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
|
|||
|
||||
class ModelDatasetAssociation(serializers.Serializer):
|
||||
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
|
||||
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("模型id"))
|
||||
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
|
||||
error_messages=ErrMessage.uuid(
|
||||
"知识库id")),
|
||||
|
|
@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
|
|||
super().is_valid(raise_exception=True)
|
||||
model_id = self.data.get('model_id')
|
||||
user_id = self.data.get('user_id')
|
||||
if not QuerySet(Model).filter(id=model_id).exists():
|
||||
raise AppApiException(500, f'模型不存在【{model_id}】')
|
||||
if model_id is not None and len(model_id) > 0:
|
||||
if not QuerySet(Model).filter(id=model_id).exists():
|
||||
raise AppApiException(500, f'模型不存在【{model_id}】')
|
||||
dataset_id_list = list(set(self.data.get('dataset_id_list')))
|
||||
exist_dataset_id_list = [str(dataset.id) for dataset in
|
||||
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
|
||||
|
|
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
max_length=256, min_length=1,
|
||||
error_messages=ErrMessage.char("应用描述"))
|
||||
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
|
||||
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("模型"))
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
|
||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
||||
error_messages=ErrMessage.char("开场白"))
|
||||
|
|
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
error_messages=ErrMessage.char("应用名称"))
|
||||
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.char("应用描述"))
|
||||
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
|
||||
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
|
||||
error_messages=ErrMessage.char("模型"))
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=False,
|
||||
error_messages=ErrMessage.boolean("多轮会话"))
|
||||
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
|
||||
|
|
@ -494,22 +498,21 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
application_id = self.data.get("application_id")
|
||||
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
|
||||
model = QuerySet(Model).filter(
|
||||
id=instance.get('model_id') if 'model_id' in instance else application.model_id,
|
||||
user_id=application.user_id).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
|
||||
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
|
||||
application.model_id = None
|
||||
else:
|
||||
model = QuerySet(Model).filter(
|
||||
id=instance.get('model_id'),
|
||||
user_id=application.user_id).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
|
||||
'dataset_setting', 'model_setting', 'problem_optimization',
|
||||
'api_key_is_active', 'icon']
|
||||
for update_key in update_keys:
|
||||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'multiple_rounds_dialogue':
|
||||
application.__setattr__('dialogue_number',
|
||||
0 if not instance.get(update_key) else ModelProvideConstants[
|
||||
model.provider].value.get_dialogue_number())
|
||||
application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3)
|
||||
else:
|
||||
application.__setattr__(update_key, instance.get(update_key))
|
||||
application.save()
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
|
|||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from common.util.split_model import flat_map
|
||||
from dataset.models import Paragraph, Document
|
||||
from setting.models import Model, Status
|
||||
|
|
@ -138,7 +138,7 @@ def get_post_handler(chat_info: ChatInfo):
|
|||
|
||||
class ChatMessageSerializer(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
|
||||
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
|
||||
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"), max_length=1024)
|
||||
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
|
||||
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
|
||||
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
|
||||
|
|
@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
chat_cache.set(chat_id,
|
||||
chat_info, timeout=60 * 30)
|
||||
model = chat_info.application.model
|
||||
if model is None:
|
||||
return chat_info
|
||||
model = QuerySet(Model).filter(id=model.id).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
return chat_info
|
||||
if model.status == Status.ERROR:
|
||||
raise AppApiException(500, "当前模型不可用")
|
||||
if model.status == Status.DOWNLOAD:
|
||||
|
|
@ -223,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer):
|
|||
# 对话模型
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
rsa_long_decrypt(model.credential)),
|
||||
streaming=True)
|
||||
# 数据集id列表
|
||||
dataset_id_list = [str(row.dataset_id) for row in
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ from common.util.common import post
|
|||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
from common.util.lock import try_lock, un_lock
|
||||
from common.util.rsa_util import decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt
|
||||
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
|
||||
from dataset.serializers.paragraph_serializers import ParagraphSerializers
|
||||
from setting.models import Model
|
||||
|
|
@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
if model is not None:
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
|
||||
chat_id = str(uuid.uuid1())
|
||||
|
|
@ -213,7 +214,8 @@ class ChatSerializers(serializers.Serializer):
|
|||
|
||||
id = serializers.UUIDField(required=False, allow_null=True,
|
||||
error_messages=ErrMessage.uuid("应用id"))
|
||||
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
|
||||
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
|
||||
error_messages=ErrMessage.uuid("模型id"))
|
||||
|
||||
multiple_rounds_dialogue = serializers.BooleanField(required=True,
|
||||
error_messages=ErrMessage.boolean("多轮会话"))
|
||||
|
|
@ -246,14 +248,18 @@ class ChatSerializers(serializers.Serializer):
|
|||
def open(self):
|
||||
user_id = self.is_valid(raise_exception=True)
|
||||
chat_id = str(uuid.uuid1())
|
||||
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
||||
if model is None:
|
||||
raise AppApiException(500, "模型不存在")
|
||||
model_id = self.data.get('model_id')
|
||||
if model_id is not None and len(model_id) > 0:
|
||||
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(
|
||||
model.credential)),
|
||||
streaming=True)
|
||||
else:
|
||||
model = None
|
||||
chat_model = None
|
||||
dataset_id_list = self.data.get('dataset_id_list')
|
||||
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
decrypt(model.credential)),
|
||||
streaming=True)
|
||||
application = Application(id=None, dialogue_number=3, model=model,
|
||||
dataset_setting=self.data.get('dataset_setting'),
|
||||
model_setting=self.data.get('model_setting'),
|
||||
|
|
|
|||
|
|
@ -62,18 +62,6 @@ def get_key_pair_by_sql():
|
|||
return system_setting.meta
|
||||
|
||||
|
||||
# def get_key_pair():
|
||||
# if not os.path.exists("/opt/maxkb/conf/receiver.pem"):
|
||||
# kv = generate()
|
||||
# private_file_out = open("/opt/maxkb/conf/private.pem", "wb")
|
||||
# private_file_out.write(kv.get('value'))
|
||||
# private_file_out.close()
|
||||
# receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb")
|
||||
# receiver_file_out.write(kv.get('key'))
|
||||
# receiver_file_out.close()
|
||||
# return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()}
|
||||
|
||||
|
||||
def encrypt(msg, public_key: str | None = None):
|
||||
"""
|
||||
加密
|
||||
|
|
@ -100,3 +88,53 @@ def decrypt(msg, pri_key: str | None = None):
|
|||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||
return decrypt_data.decode("utf-8")
|
||||
|
||||
|
||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
||||
"""
|
||||
超长文本加密
|
||||
|
||||
:param message: 需要加密的字符串
|
||||
:param public_key 公钥
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:return: 加密后的数据
|
||||
"""
|
||||
# 读取公钥
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
|
||||
passphrase=secret_code))
|
||||
# 处理:Plaintext is too long. 分段加密
|
||||
if len(message) <= length:
|
||||
# 对编码的数据进行加密,并通过base64进行编码
|
||||
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
||||
else:
|
||||
rsa_text = []
|
||||
# 对编码后的数据进行切片,原因:加密长度不能过长
|
||||
for i in range(0, len(message), length):
|
||||
cont = message[i:i + length]
|
||||
# 对切片后的数据进行加密,并新增到text后面
|
||||
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
||||
# 加密完进行拼接
|
||||
cipher_text = b''.join(rsa_text)
|
||||
# base64进行编码
|
||||
result = base64.b64encode(cipher_text)
|
||||
return result.decode()
|
||||
|
||||
|
||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
||||
"""
|
||||
超长文本解密,默认不加密
|
||||
:param message: 需要解密的数据
|
||||
:param pri_key: 秘钥
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:return: 解密后的数据
|
||||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
base64_de = base64.b64decode(message)
|
||||
res = []
|
||||
for i in range(0, len(base64_de), length):
|
||||
res.append(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
return b"".join(res).decode()
|
||||
|
|
|
|||
|
|
@ -164,6 +164,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
elif target_dataset.type == Type.base.value and dataset.type == Type.web.value:
|
||||
document_list.update(dataset_id=target_dataset_id, type=Type.base,
|
||||
meta={})
|
||||
else:
|
||||
document_list.update(dataset_id=target_dataset_id)
|
||||
paragraph_list.update(dataset_id=target_dataset_id)
|
||||
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
|
||||
[problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
|
||||
|
|
@ -713,6 +715,19 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
|
|||
ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
|
||||
return True
|
||||
|
||||
def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
|
||||
hit_handling_method = instance.get('hit_handling_method')
|
||||
if hit_handling_method is None:
|
||||
raise AppApiException(500, '命中处理方式必填')
|
||||
if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return':
|
||||
raise AppApiException(500, '命中处理方式必须为directly_return|optimization')
|
||||
self.is_valid(raise_exception=True)
|
||||
document_id_list = instance.get("id_list")
|
||||
hit_handling_method = instance.get('hit_handling_method')
|
||||
QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method)
|
||||
|
||||
|
||||
class FileBufferHandle:
|
||||
buffer = None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: maxkb
|
||||
@Author:虎
|
||||
@file: document_api.py
|
||||
@date:2024/4/28 13:56
|
||||
@desc:
|
||||
"""
|
||||
from drf_yasg import openapi
|
||||
|
||||
from common.mixins.api_mixin import ApiMixin
|
||||
|
||||
|
||||
class DocumentApi(ApiMixin):
|
||||
class BatchEditHitHandlingApi(ApiMixin):
|
||||
@staticmethod
|
||||
def get_request_body_api():
|
||||
return openapi.Schema(
|
||||
type=openapi.TYPE_OBJECT,
|
||||
properties={
|
||||
'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
|
||||
title="主键id列表",
|
||||
description="主键id列表"),
|
||||
'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
|
||||
description="directly_return|optimization")
|
||||
}
|
||||
)
|
||||
|
|
@ -14,6 +14,7 @@ urlpatterns = [
|
|||
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
|
||||
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
|
||||
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
|
||||
name="document_operate"),
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from common.response import result
|
|||
from common.util.common import query_params_to_single_dict
|
||||
from dataset.serializers.common_serializers import BatchSerializer
|
||||
from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer
|
||||
from dataset.swagger_api.document_api import DocumentApi
|
||||
|
||||
|
||||
class WebDocument(APIView):
|
||||
|
|
@ -71,6 +72,24 @@ class Document(APIView):
|
|||
d.is_valid(raise_exception=True)
|
||||
return result.success(d.list())
|
||||
|
||||
class BatchEditHitHandling(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
@action(methods=['POST'], detail=False)
|
||||
@swagger_auto_schema(operation_summary="批量修改文档命中处理方式",
|
||||
operation_id="批量修改文档命中处理方式",
|
||||
request_body=
|
||||
DocumentApi.BatchEditHitHandlingApi.get_request_body_api(),
|
||||
manual_parameters=DocumentSerializers.Create.get_request_params_api(),
|
||||
responses=result.get_default_response(),
|
||||
tags=["知识库/文档"])
|
||||
@has_permissions(
|
||||
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
|
||||
dynamic_tag=k.get('dataset_id')))
|
||||
def put(self, request: Request, dataset_id: str):
|
||||
return result.success(
|
||||
DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data))
|
||||
|
||||
class Batch(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
|
|
|
|||
|
|
@ -18,27 +18,30 @@ def update_embedding_search_vector(embedding, paragraph_list):
|
|||
|
||||
|
||||
def save_keywords(apps, schema_editor):
|
||||
document = apps.get_model("dataset", "Document")
|
||||
embedding = apps.get_model("embedding", "Embedding")
|
||||
paragraph = apps.get_model('dataset', 'Paragraph')
|
||||
db_alias = schema_editor.connection.alias
|
||||
document_list = document.objects.using(db_alias).all()
|
||||
for document in document_list:
|
||||
document.status = Status.embedding
|
||||
document.save()
|
||||
paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
|
||||
embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
|
||||
'paragraph')
|
||||
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
|
||||
in embedding_list]
|
||||
child_array = sub_array(embedding_update_list, 50)
|
||||
for c in child_array:
|
||||
try:
|
||||
embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
document.status = Status.success
|
||||
document.save()
|
||||
try:
|
||||
document = apps.get_model("dataset", "Document")
|
||||
embedding = apps.get_model("embedding", "Embedding")
|
||||
paragraph = apps.get_model('dataset', 'Paragraph')
|
||||
db_alias = schema_editor.connection.alias
|
||||
document_list = document.objects.using(db_alias).all()
|
||||
for document in document_list:
|
||||
document.status = Status.embedding
|
||||
document.save()
|
||||
paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
|
||||
embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
|
||||
'paragraph')
|
||||
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
|
||||
in embedding_list]
|
||||
child_array = sub_array(embedding_update_list, 50)
|
||||
for c in child_array:
|
||||
try:
|
||||
embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
|
||||
except Exception as e:
|
||||
print(e)
|
||||
document.status = Status.success
|
||||
document.save()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.13 on 2024-04-28 18:06
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('setting', '0003_model_meta_model_status'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name='model',
|
||||
name='credential',
|
||||
field=models.CharField(max_length=102400, verbose_name='模型认证信息'),
|
||||
),
|
||||
]
|
||||
|
|
@ -42,7 +42,7 @@ class Model(AppModelMixin):
|
|||
|
||||
provider = models.CharField(max_length=128, verbose_name='供应商')
|
||||
|
||||
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
|
||||
credential = models.CharField(max_length=102400, verbose_name="模型认证信息")
|
||||
|
||||
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from rest_framework import serializers
|
|||
from application.models import Application
|
||||
from common.exception.app_exception import AppApiException
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.rsa_util import encrypt, decrypt
|
||||
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
|
||||
from setting.models.model_management import Model, Status
|
||||
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
|
||||
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
|
||||
|
|
@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
|
||||
model_name)
|
||||
source_model_credential = json.loads(decrypt(model.credential))
|
||||
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
|
||||
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
|
||||
if credential is not None:
|
||||
for k in source_encryption_model_credential.keys():
|
||||
|
|
@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
model_name = self.data.get('model_name')
|
||||
model_credential_str = json.dumps(credential)
|
||||
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
|
||||
credential=encrypt(model_credential_str),
|
||||
credential=rsa_long_encrypt(model_credential_str),
|
||||
provider=provider, model_type=model_type, model_name=model_name)
|
||||
model.save()
|
||||
if status == Status.DOWNLOAD:
|
||||
|
|
@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
|
||||
@staticmethod
|
||||
def model_to_dict(model: Model):
|
||||
credential = json.loads(decrypt(model.credential))
|
||||
credential = json.loads(rsa_long_decrypt(model.credential))
|
||||
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
|
||||
'model_name': model.model_name,
|
||||
'status': model.status,
|
||||
|
|
@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
if update_key in instance and instance.get(update_key) is not None:
|
||||
if update_key == 'credential':
|
||||
model_credential_str = json.dumps(credential)
|
||||
model.__setattr__(update_key, encrypt(model_credential_str))
|
||||
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
|
||||
else:
|
||||
model.__setattr__(update_key, instance.get(update_key))
|
||||
model.save()
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ const postDataset: (data: datasetData, loading?: Ref<boolean>) => Promise<Result
|
|||
data,
|
||||
loading
|
||||
) => {
|
||||
return post(`${prefix}`, data, undefined, loading)
|
||||
return post(`${prefix}`, data, undefined, loading, 1000 * 60 * 5)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ const prefix = '/dataset'
|
|||
* @param 参数 file:file,limit:number,patterns:array,with_filter:boolean
|
||||
*/
|
||||
const postSplitDocument: (data: any) => Promise<Result<any>> = (data) => {
|
||||
return post(`${prefix}/document/split`, data)
|
||||
return post(`${prefix}/document/split`, data, undefined, undefined, 1000 * 60 * 60)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -80,7 +80,7 @@ const postDocument: (
|
|||
data: any,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<any>> = (dataset_id, data, loading) => {
|
||||
return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading)
|
||||
return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading, 1000 * 60 * 5)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -206,6 +206,20 @@ const putMigrateMulDocument: (
|
|||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量修改命中方式
|
||||
* @param dataset_id 知识库id
|
||||
* @param data {id_list:[],hit_handling_method:'directly_return|optimization'}
|
||||
* @param loading
|
||||
* @returns
|
||||
*/
|
||||
const batchEditHitHandling: (
|
||||
dataset_id: string,
|
||||
data: any,
|
||||
loading?: Ref<boolean>
|
||||
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
|
||||
return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading)
|
||||
}
|
||||
export default {
|
||||
postSplitDocument,
|
||||
getDocument,
|
||||
|
|
@ -219,5 +233,6 @@ export default {
|
|||
putDocumentRefresh,
|
||||
delMulSyncDocument,
|
||||
postWebDocument,
|
||||
putMigrateMulDocument
|
||||
putMigrateMulDocument,
|
||||
batchEditHitHandling
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ export class ChatRecordManage {
|
|||
this.chat.answer_text =
|
||||
this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')
|
||||
} else if (this.is_close) {
|
||||
this.chat.answer_text = this.chat.answer_text + this.chat.buffer.join('')
|
||||
this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('')
|
||||
this.chat.write_ed = true
|
||||
this.write_ed = true
|
||||
if (this.loading) {
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
v-model="dialogVisible"
|
||||
destroy-on-close
|
||||
append-to-body
|
||||
align-center
|
||||
>
|
||||
<div class="paragraph-source-height">
|
||||
<el-scrollbar>
|
||||
|
|
@ -63,7 +64,7 @@
|
|||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, watch, nextTick } from 'vue'
|
||||
import { ref, watch, onBeforeUnmount } from 'vue'
|
||||
import { cloneDeep } from 'lodash'
|
||||
import { arraySort } from '@/utils/utils'
|
||||
const emit = defineEmits(['refresh'])
|
||||
|
|
@ -86,12 +87,15 @@ const open = (data: any, id?: string) => {
|
|||
detail.value.paragraph_list = arraySort(detail.value.paragraph_list, 'similarity', true)
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
onBeforeUnmount(() => {
|
||||
dialogVisible.value = false
|
||||
})
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss">
|
||||
.paragraph-source {
|
||||
padding: 0;
|
||||
|
||||
.el-dialog__header {
|
||||
padding: 24px 24px 0 24px;
|
||||
}
|
||||
|
|
@ -102,4 +106,9 @@ defineExpose({ open })
|
|||
height: calc(100vh - 260px);
|
||||
}
|
||||
}
|
||||
@media only screen and (max-width: 768px) {
|
||||
.paragraph-source {
|
||||
width: 90% !important;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
|
|
|
|||
|
|
@ -224,7 +224,7 @@ const chartOpenId = ref('')
|
|||
const chatList = ref<any[]>([])
|
||||
|
||||
const isDisabledChart = computed(
|
||||
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id)))
|
||||
() => !(inputValue.value.trim() && (props.appId || props.data?.name))
|
||||
)
|
||||
const isMdArray = (val: string) => val.match(/^-\s.*/m)
|
||||
const prologueList = computed(() => {
|
||||
|
|
@ -274,12 +274,7 @@ function openParagraph(row: any, id?: string) {
|
|||
}
|
||||
|
||||
function quickProblemHandle(val: string) {
|
||||
if (!props.log && !loading.value && props.data?.name && props.data?.model_id) {
|
||||
// inputValue.value = val
|
||||
// nextTick(() => {
|
||||
// quickInputRef.value?.focus()
|
||||
// })
|
||||
|
||||
if (!loading.value && props.data?.name) {
|
||||
handleDebounceClick(val)
|
||||
}
|
||||
}
|
||||
|
|
@ -509,16 +504,14 @@ function regenerationChart(item: chatType) {
|
|||
}
|
||||
|
||||
function getSourceDetail(row: any) {
|
||||
logApi
|
||||
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading)
|
||||
.then((res) => {
|
||||
const exclude_keys = ['answer_text', 'id']
|
||||
Object.keys(res.data).forEach((key) => {
|
||||
if (!exclude_keys.includes(key)) {
|
||||
row[key] = res.data[key]
|
||||
}
|
||||
})
|
||||
logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
|
||||
const exclude_keys = ['answer_text', 'id']
|
||||
Object.keys(res.data).forEach((key) => {
|
||||
if (!exclude_keys.includes(key)) {
|
||||
row[key] = res.data[key]
|
||||
}
|
||||
})
|
||||
})
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -119,9 +119,15 @@ const promise: (
|
|||
export const get: (
|
||||
url: string,
|
||||
params?: unknown,
|
||||
loading?: NProgress | Ref<boolean>
|
||||
) => Promise<Result<any>> = (url: string, params: unknown, loading?: NProgress | Ref<boolean>) => {
|
||||
return promise(request({ url: url, method: 'get', params }), loading)
|
||||
loading?: NProgress | Ref<boolean>,
|
||||
timeout?: number
|
||||
) => Promise<Result<any>> = (
|
||||
url: string,
|
||||
params: unknown,
|
||||
loading?: NProgress | Ref<boolean>,
|
||||
timeout?: number
|
||||
) => {
|
||||
return promise(request({ url: url, method: 'get', params, timeout: timeout }), loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -136,9 +142,10 @@ export const post: (
|
|||
url: string,
|
||||
data?: unknown,
|
||||
params?: unknown,
|
||||
loading?: NProgress | Ref<boolean>
|
||||
) => Promise<Result<any> | any> = (url, data, params, loading) => {
|
||||
return promise(request({ url: url, method: 'post', data, params }), loading)
|
||||
loading?: NProgress | Ref<boolean>,
|
||||
timeout?: number
|
||||
) => Promise<Result<any> | any> = (url, data, params, loading, timeout) => {
|
||||
return promise(request({ url: url, method: 'post', data, params, timeout }), loading)
|
||||
}
|
||||
|
||||
/**|
|
||||
|
|
@ -153,9 +160,10 @@ export const put: (
|
|||
url: string,
|
||||
data?: unknown,
|
||||
params?: unknown,
|
||||
loading?: NProgress | Ref<boolean>
|
||||
) => Promise<Result<any>> = (url, data, params, loading) => {
|
||||
return promise(request({ url: url, method: 'put', data, params }), loading)
|
||||
loading?: NProgress | Ref<boolean>,
|
||||
timeout?: number
|
||||
) => Promise<Result<any>> = (url, data, params, loading, timeout) => {
|
||||
return promise(request({ url: url, method: 'put', data, params, timeout }), loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -169,9 +177,10 @@ export const del: (
|
|||
url: string,
|
||||
params?: unknown,
|
||||
data?: unknown,
|
||||
loading?: NProgress | Ref<boolean>
|
||||
) => Promise<Result<any>> = (url, params, data, loading) => {
|
||||
return promise(request({ url: url, method: 'delete', params, data }), loading)
|
||||
loading?: NProgress | Ref<boolean>,
|
||||
timeout?: number
|
||||
) => Promise<Result<any>> = (url, params, data, loading, timeout) => {
|
||||
return promise(request({ url: url, method: 'delete', params, data, timeout }), loading)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@
|
|||
accept="image/*"
|
||||
:on-change="onChange"
|
||||
>
|
||||
<el-button icon="Upload">上传</el-button>
|
||||
<el-button icon="Upload" :disabled="radioType !== 'custom'">上传</el-button>
|
||||
</el-upload>
|
||||
</div>
|
||||
<div class="el-upload__tip info mt-16">
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@
|
|||
<el-form-item label="AI 模型" prop="model_id">
|
||||
<template #label>
|
||||
<div class="flex-between">
|
||||
<span>AI 模型 <span class="danger">*</span></span>
|
||||
<span>AI 模型 </span>
|
||||
</div>
|
||||
</template>
|
||||
<el-select
|
||||
|
|
@ -56,6 +56,7 @@
|
|||
placeholder="请选择 AI 模型"
|
||||
class="w-full"
|
||||
popper-class="select-model"
|
||||
:clearable="true"
|
||||
>
|
||||
<el-option-group
|
||||
v-for="(value, label) in modelOptions"
|
||||
|
|
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
|
|||
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
|
||||
model_id: [
|
||||
{
|
||||
required: true,
|
||||
required: false,
|
||||
message: '请选择模型',
|
||||
trigger: 'change'
|
||||
}
|
||||
|
|
|
|||
|
|
@ -251,7 +251,7 @@ defineExpose({ open })
|
|||
padding: 0 !important;
|
||||
}
|
||||
.dialog-max-height {
|
||||
height: calc(100vh - 180px);
|
||||
height: 550px;
|
||||
}
|
||||
.custom-slider {
|
||||
.el-input-number.is-without-controls .el-input__wrapper {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,96 @@
|
|||
<template>
|
||||
<el-dialog
|
||||
title="设置"
|
||||
v-model="dialogVisible"
|
||||
:close-on-click-modal="false"
|
||||
:close-on-press-escape="false"
|
||||
:destroy-on-close="true"
|
||||
width="400"
|
||||
>
|
||||
<el-form
|
||||
label-position="top"
|
||||
ref="webFormRef"
|
||||
:rules="rules"
|
||||
:model="form"
|
||||
require-asterisk-position="right"
|
||||
>
|
||||
<el-form-item>
|
||||
<template #label>
|
||||
<div class="flex align-center">
|
||||
<span class="mr-4">命中处理方式</span>
|
||||
<el-tooltip
|
||||
effect="dark"
|
||||
content="用户提问时,命中文档下的分段时按照设置的方式进行处理。"
|
||||
placement="right"
|
||||
>
|
||||
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
|
||||
</el-tooltip>
|
||||
</div>
|
||||
</template>
|
||||
<el-radio-group v-model="form.hit_handling_method">
|
||||
<template v-for="(value, key) of hitHandlingMethod" :key="key">
|
||||
<el-radio :value="key">{{ value }}</el-radio>
|
||||
</template>
|
||||
</el-radio-group>
|
||||
</el-form-item>
|
||||
</el-form>
|
||||
<template #footer>
|
||||
<span class="dialog-footer">
|
||||
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
|
||||
<el-button type="primary" @click="submit(webFormRef)" :loading="loading"> 确定 </el-button>
|
||||
</span>
|
||||
</template>
|
||||
</el-dialog>
|
||||
</template>
|
||||
<script setup lang="ts">
|
||||
import { ref, reactive, watch } from 'vue'
|
||||
import { useRoute } from 'vue-router'
|
||||
import type { FormInstance, FormRules } from 'element-plus'
|
||||
import documentApi from '@/api/document'
|
||||
import { MsgSuccess } from '@/utils/message'
|
||||
import { hitHandlingMethod } from '../utils'
|
||||
|
||||
const route = useRoute()
|
||||
const {
|
||||
params: { id }
|
||||
} = route as any
|
||||
|
||||
const emit = defineEmits(['refresh'])
|
||||
const webFormRef = ref()
|
||||
const loading = ref<boolean>(false)
|
||||
const documentList = ref<Array<string>>([])
|
||||
const form = ref<any>({
|
||||
hit_handling_method: 'optimization'
|
||||
})
|
||||
|
||||
const rules = reactive({
|
||||
source_url: [{ required: true, message: '请输入文档地址', trigger: 'blur' }]
|
||||
})
|
||||
|
||||
const dialogVisible = ref<boolean>(false)
|
||||
|
||||
const open = (list: Array<string>) => {
|
||||
documentList.value = list
|
||||
dialogVisible.value = true
|
||||
}
|
||||
|
||||
const submit = async (formEl: FormInstance | undefined) => {
|
||||
if (!formEl) return
|
||||
await formEl.validate((valid, fields) => {
|
||||
if (valid) {
|
||||
const obj = {
|
||||
hit_handling_method: form.value.hit_handling_method,
|
||||
id_list: documentList.value
|
||||
}
|
||||
documentApi.batchEditHitHandling(id, obj, loading).then((res: any) => {
|
||||
MsgSuccess('设置成功')
|
||||
emit('refresh')
|
||||
dialogVisible.value = false
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
defineExpose({ open })
|
||||
</script>
|
||||
<style lang="scss" scoped></style>
|
||||
|
|
@ -21,10 +21,13 @@
|
|||
>同步文档</el-button
|
||||
>
|
||||
<el-button @click="openDatasetDialog()" :disabled="multipleSelection.length === 0"
|
||||
>批量迁移</el-button
|
||||
>迁移</el-button
|
||||
>
|
||||
<el-button @click="openBatchEditDocument" :disabled="multipleSelection.length === 0"
|
||||
>设置</el-button
|
||||
>
|
||||
<el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0"
|
||||
>批量删除</el-button
|
||||
>删除</el-button
|
||||
>
|
||||
</div>
|
||||
|
||||
|
|
@ -212,6 +215,10 @@
|
|||
</div>
|
||||
<ImportDocumentDialog ref="ImportDocumentDialogRef" :title="title" @refresh="refresh" />
|
||||
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
|
||||
<BatchEditDocumentDialog
|
||||
ref="batchEditDocumentDialogRef"
|
||||
@refresh="refresh"
|
||||
></BatchEditDocumentDialog>
|
||||
<!-- 选择知识库 -->
|
||||
<SelectDatasetDialog ref="SelectDatasetDialogRef" @refresh="refresh" />
|
||||
</div>
|
||||
|
|
@ -225,6 +232,7 @@ import documentApi from '@/api/document'
|
|||
import ImportDocumentDialog from './component/ImportDocumentDialog.vue'
|
||||
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
|
||||
import SelectDatasetDialog from './component/SelectDatasetDialog.vue'
|
||||
import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue'
|
||||
import { numberFormat } from '@/utils/utils'
|
||||
import { datetimeFormat } from '@/utils/time'
|
||||
import { hitHandlingMethod } from './utils'
|
||||
|
|
@ -257,7 +265,7 @@ onBeforeRouteLeave((to: any, from: any) => {
|
|||
})
|
||||
const beforePagination = computed(() => common.paginationConfig[storeKey])
|
||||
const beforeSearch = computed(() => common.search[storeKey])
|
||||
|
||||
const batchEditDocumentDialogRef = ref<InstanceType<typeof BatchEditDocumentDialog>>()
|
||||
const SyncWebDialogRef = ref()
|
||||
const loading = ref(false)
|
||||
let interval: any
|
||||
|
|
@ -317,6 +325,13 @@ const handleSelectionChange = (val: any[]) => {
|
|||
multipleSelection.value = val
|
||||
}
|
||||
|
||||
function openBatchEditDocument() {
|
||||
const arr: string[] = multipleSelection.value.map((v) => v.id)
|
||||
if (batchEditDocumentDialogRef) {
|
||||
batchEditDocumentDialogRef?.value?.open(arr)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 初始化轮询
|
||||
*/
|
||||
|
|
@ -356,9 +371,9 @@ function refreshDocument(row: any) {
|
|||
.catch(() => {})
|
||||
}
|
||||
} else {
|
||||
// documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => {
|
||||
// getList()
|
||||
// })
|
||||
documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => {
|
||||
getList()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<template>
|
||||
<el-drawer v-model="visible" size="60%" @close="closeHandle" class="chat-record-drawer">
|
||||
<template #header>
|
||||
<h4>{{ currentAbstract }}</h4>
|
||||
<h4 class="single-line">{{ currentAbstract }}</h4>
|
||||
</template>
|
||||
<div
|
||||
v-loading="paginationConfig.current_page === 1 && loading"
|
||||
|
|
@ -120,6 +120,11 @@ defineExpose({
|
|||
})
|
||||
</script>
|
||||
<style lang="scss">
|
||||
.single-line {
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
.chat-record-drawer {
|
||||
.el-drawer__body {
|
||||
background: var(--app-layout-bg-color);
|
||||
|
|
|
|||
Loading…
Reference in New Issue