This commit is contained in:
tongque 2024-05-01 11:51:23 +08:00
commit 93a959b3ef
31 changed files with 479 additions and 167 deletions

View File

@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep):
message_list = serializers.ListField(required=True, child=MessageField(required=True),
error_messages=ErrMessage.list("对话列表"))
# 大语言模型
chat_model = ModelField(error_messages=ErrMessage.list("大语言模型"))
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型"))
# 段落列表
paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
# 对话id

View File

@ -59,8 +59,12 @@ def event_content(response,
# 获取token
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(all_text)
except Exception as e:
request_token = 0
response_token = 0
else:
request_token = 0
response_token = 0
@ -126,6 +130,26 @@ class BaseChatStep(IChatStep):
result.append({'role': 'ai', 'content': answer_text})
return result
@staticmethod
def get_stream_result(message_list: List[BaseMessage],
                      chat_model: BaseChatModel = None,
                      paragraph_list=None,
                      no_references_setting=None):
    """Pick the streaming answer source for a chat request.

    :param message_list: prompt messages to send if the LLM is used
    :param chat_model: the chat model, or None when none is configured
    :param paragraph_list: retrieved knowledge-base paragraphs (may be None)
    :param no_references_setting: behavior when no references were hit;
        NOTE(review): assumed to be a dict with 'status'/'value' — confirm callers
    :return: (chunk_iterator, is_ai_chat) — is_ai_chat is True only when the
        model was actually invoked (callers use it to decide token counting)
    """
    if paragraph_list is None:
        paragraph_list = []
    # Paragraphs flagged 'directly_return' bypass the LLM and are streamed verbatim.
    directly_return_chunk_list = [AIMessageChunk(content=paragraph.content)
                                  for paragraph in paragraph_list if
                                  paragraph.hit_handling_method == 'directly_return']
    # A list comprehension is never None, so only the emptiness check is needed.
    if len(directly_return_chunk_list) > 0:
        return iter(directly_return_chunk_list), False
    elif len(paragraph_list) == 0 and no_references_setting.get(
            'status') == 'designated_answer':
        # No references hit: stream the configured designated answer.
        return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False
    if chat_model is None:
        # No model configured: stream a fixed notice instead of failing.
        return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False
    else:
        return chat_model.stream(message_list), True
def execute_stream(self, message_list: List[BaseMessage],
chat_id,
problem_text,
@ -136,29 +160,8 @@ class BaseChatStep(IChatStep):
padding_problem_text: str = None,
client_id=None, client_type=None,
no_references_setting=None):
is_ai_chat = False
# 调用模型
if chat_model is None:
chat_result = iter(
[AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list])
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))])
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting)
chat_record_id = uuid.uuid1()
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
@ -169,6 +172,27 @@ class BaseChatStep(IChatStep):
r['Cache-Control'] = 'no-cache'
return r
@staticmethod
def get_block_result(message_list: List[BaseMessage],
                     chat_model: BaseChatModel = None,
                     paragraph_list=None,
                     no_references_setting=None):
    """Pick the blocking (non-streaming) answer source for a chat request.

    Mirrors ``get_stream_result`` but returns a single ``AIMessage``
    rather than a chunk iterator.

    :param message_list: prompt messages to send if the LLM is used
    :param chat_model: the chat model, or None when none is configured
    :param paragraph_list: retrieved knowledge-base paragraphs (may be None)
    :param no_references_setting: behavior when no references were hit;
        NOTE(review): assumed to be a dict with 'status'/'value' — confirm callers
    :return: (AIMessage, is_ai_chat) — is_ai_chat is True only when the
        model was actually invoked
    """
    if paragraph_list is None:
        paragraph_list = []
    # Paragraphs flagged 'directly_return' bypass the LLM entirely.
    directly_return_chunk_list = [AIMessage(content=paragraph.content)
                                  for paragraph in paragraph_list if
                                  paragraph.hit_handling_method == 'directly_return']
    # A list comprehension is never None, so only the emptiness check is needed.
    if len(directly_return_chunk_list) > 0:
        # Only the first directly-returned paragraph is used in blocking mode.
        return directly_return_chunk_list[0], False
    elif len(paragraph_list) == 0 and no_references_setting.get(
            'status') == 'designated_answer':
        return AIMessage(no_references_setting.get('value')), False
    if chat_model is None:
        # No model configured: return a fixed notice instead of failing.
        return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False
    else:
        return chat_model.invoke(message_list), True
def execute_block(self, message_list: List[BaseMessage],
chat_id,
problem_text,
@ -178,28 +202,8 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None,
padding_problem_text: str = None,
client_id=None, client_type=None, no_references_setting=None):
is_ai_chat = False
# 调用模型
if chat_model is None:
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get(
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting)
chat_record_id = uuid.uuid1()
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)

View File

@ -28,7 +28,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
padding_problem_text: str = None,
no_references_setting=None,
**kwargs) -> List[BaseMessage]:
prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get(
prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get(
'value')
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
start_index = len(history_chat_record) - dialogue_number

View File

@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep):
history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True),
error_messages=ErrMessage.list("历史对答"))
# 大语言模型
chat_model = ModelField(error_messages=ErrMessage.base("大语言模型"))
chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型"))
def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
return self.InstanceSerializer

View File

@ -22,6 +22,10 @@ prompt = (
class BaseResetProblemStep(IResetProblemStep):
def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None,
**kwargs) -> str:
if chat_model is None:
self.context['message_tokens'] = 0
self.context['answer_tokens'] = 0
return problem_text
start_index = len(history_chat_record) - 3
history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()]
for index in
@ -35,8 +39,14 @@ class BaseResetProblemStep(IResetProblemStep):
response.content.index('<data>') + 6:response.content.index('</data>')]
if padding_problem_data is not None and len(padding_problem_data.strip()) > 0:
padding_problem = padding_problem_data
self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list)
self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem)
try:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(padding_problem)
except Exception as e:
request_token = 0
response_token = 0
self.context['message_tokens'] = request_token
self.context['answer_tokens'] = response_token
return padding_problem
def get_details(self, manage, **kwargs):

View File

@ -0,0 +1,23 @@
# Generated by Django 4.1.13 on 2024-04-29 13:33
from django.db import migrations, models
class Migration(migrations.Migration):
    # Widens two CharField limits so long content is no longer truncated:
    # Chat.abstract 256 -> 1024 and ChatRecord.answer_text 4096 -> 40960.

    dependencies = [
        ('application', '0004_applicationaccesstoken_show_source'),
    ]

    operations = [
        migrations.AlterField(
            model_name='chat',
            name='abstract',
            field=models.CharField(max_length=1024, verbose_name='摘要'),
        ),
        migrations.AlterField(
            model_name='chatrecord',
            name='answer_text',
            field=models.CharField(max_length=40960, verbose_name='答案'),
        ),
    ]

View File

@ -73,7 +73,7 @@ class ApplicationDatasetMapping(AppModelMixin):
class Chat(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
abstract = models.CharField(max_length=256, verbose_name="摘要")
abstract = models.CharField(max_length=1024, verbose_name="摘要")
client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
class Meta:
@ -96,7 +96,7 @@ class ChatRecord(AppModelMixin):
vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices,
default=VoteChoices.UN_VOTE)
problem_text = models.CharField(max_length=1024, verbose_name="问题")
answer_text = models.CharField(max_length=4096, verbose_name="答案")
answer_text = models.CharField(max_length=40960, verbose_name="答案")
message_tokens = models.IntegerField(verbose_name="请求token数量", default=0)
answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0)
const = models.IntegerField(verbose_name="总费用", default=0)

View File

@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache']
class ModelDatasetAssociation(serializers.Serializer):
user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型id"))
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid(
"知识库id")),
@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer):
super().is_valid(raise_exception=True)
model_id = self.data.get('model_id')
user_id = self.data.get('user_id')
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
if model_id is not None and len(model_id) > 0:
if not QuerySet(Model).filter(id=model_id).exists():
raise AppApiException(500, f'模型不存在【{model_id}')
dataset_id_list = list(set(self.data.get('dataset_id_list')))
exist_dataset_id_list = [str(dataset.id) for dataset in
QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)]
@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer):
desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
max_length=256, min_length=1,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
error_messages=ErrMessage.char("开场白"))
@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer):
error_messages=ErrMessage.char("应用名称"))
desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True,
error_messages=ErrMessage.char("应用描述"))
model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型"))
model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True,
error_messages=ErrMessage.char("模型"))
multiple_rounds_dialogue = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("多轮会话"))
prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024,
@ -494,22 +498,21 @@ class ApplicationSerializer(serializers.Serializer):
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
model = QuerySet(Model).filter(
id=instance.get('model_id') if 'model_id' in instance else application.model_id,
user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
application.model_id = None
else:
model = QuerySet(Model).filter(
id=instance.get('model_id'),
user_id=application.user_id).first()
if model is None:
raise AppApiException(500, "模型不存在")
update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
'dataset_setting', 'model_setting', 'problem_optimization',
'api_key_is_active', 'icon']
for update_key in update_keys:
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'multiple_rounds_dialogue':
application.__setattr__('dialogue_number',
0 if not instance.get(update_key) else ModelProvideConstants[
model.provider].value.get_dialogue_number())
application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3)
else:
application.__setattr__(update_key, instance.get(update_key))
application.save()

View File

@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.util.field_message import ErrMessage
from common.util.rsa_util import decrypt
from common.util.rsa_util import rsa_long_decrypt
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
from setting.models import Model, Status
@ -138,7 +138,7 @@ def get_post_handler(chat_info: ChatInfo):
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"), max_length=1024)
stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答"))
re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答"))
application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id"))
@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer):
chat_cache.set(chat_id,
chat_info, timeout=60 * 30)
model = chat_info.application.model
if model is None:
return chat_info
model = QuerySet(Model).filter(id=model.id).first()
if model is None:
raise AppApiException(500, "模型不存在")
return chat_info
if model.status == Status.ERROR:
raise AppApiException(500, "当前模型不可用")
if model.status == Status.DOWNLOAD:
@ -223,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer):
# 对话模型
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
rsa_long_decrypt(model.credential)),
streaming=True)
# 数据集id列表
dataset_id_list = [str(row.dataset_id) for row in

View File

@ -35,7 +35,7 @@ from common.util.common import post
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from common.util.rsa_util import rsa_long_decrypt
from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer):
if model is not None:
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
rsa_long_decrypt(
model.credential)),
streaming=True)
chat_id = str(uuid.uuid1())
@ -213,7 +214,8 @@ class ChatSerializers(serializers.Serializer):
id = serializers.UUIDField(required=False, allow_null=True,
error_messages=ErrMessage.uuid("应用id"))
model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id"))
model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
error_messages=ErrMessage.uuid("模型id"))
multiple_rounds_dialogue = serializers.BooleanField(required=True,
error_messages=ErrMessage.boolean("多轮会话"))
@ -246,14 +248,18 @@ class ChatSerializers(serializers.Serializer):
def open(self):
user_id = self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
if model is None:
raise AppApiException(500, "模型不存在")
model_id = self.data.get('model_id')
if model_id is not None and len(model_id) > 0:
model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first()
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(
model.credential)),
streaming=True)
else:
model = None
chat_model = None
dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
application = Application(id=None, dialogue_number=3, model=model,
dataset_setting=self.data.get('dataset_setting'),
model_setting=self.data.get('model_setting'),

View File

@ -62,18 +62,6 @@ def get_key_pair_by_sql():
return system_setting.meta
# def get_key_pair():
# if not os.path.exists("/opt/maxkb/conf/receiver.pem"):
# kv = generate()
# private_file_out = open("/opt/maxkb/conf/private.pem", "wb")
# private_file_out.write(kv.get('value'))
# private_file_out.close()
# receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb")
# receiver_file_out.write(kv.get('key'))
# receiver_file_out.close()
# return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()}
def encrypt(msg, public_key: str | None = None):
"""
加密
@ -100,3 +88,53 @@ def decrypt(msg, pri_key: str | None = None):
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
    """RSA-encrypt a string of arbitrary length.

    PKCS#1 encryption is limited by the key size, so the plaintext is split
    into ``length``-sized pieces, each encrypted separately, and the
    concatenated ciphertext is base64-encoded.

    :param message: plaintext string to encrypt
    :param public_key: PEM public key; defaults to the system key pair
    :param length: segment size — 100 for a 1024-bit key, 200 for 2048-bit
    :return: base64-encoded ciphertext as a str
    """
    # Fall back to the stored system public key.
    if public_key is None:
        public_key = get_key_pair().get('key')
    cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
                                            passphrase=secret_code))
    # Avoid "Plaintext is too long": encrypt in segments when needed.
    if len(message) <= length:
        # Short message: single encrypt + base64 encode.
        result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
    else:
        rsa_text = []
        # Slice the message; each slice must fit within the RSA block size.
        for i in range(0, len(message), length):
            cont = message[i:i + length]
            # Encrypt the slice and collect it.
            rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
        # Join the encrypted segments into one ciphertext blob.
        cipher_text = b''.join(rsa_text)
        # Base64-encode the combined ciphertext.
        result = base64.b64encode(cipher_text)
    return result.decode()
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
    """RSA-decrypt base64 ciphertext produced by ``rsa_long_encrypt``.

    The decoded ciphertext is processed in ``length``-byte blocks (one RSA
    block per segment) and the decrypted pieces are joined back together.

    :param message: base64-encoded ciphertext
    :param pri_key: PEM private key; defaults to the system key pair
    :param length: RSA block size in bytes — 128 for 1024-bit, 256 for 2048-bit
    :return: decrypted plaintext as a str
    """
    if pri_key is None:
        pri_key = get_key_pair().get('value')
    cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
    base64_de = base64.b64decode(message)
    res = []
    # Decrypt block by block; the 0 sentinel is returned on failure.
    for i in range(0, len(base64_de), length):
        res.append(cipher.decrypt(base64_de[i:i + length], 0))
    return b"".join(res).decode()

View File

@ -164,6 +164,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
elif target_dataset.type == Type.base.value and dataset.type == Type.web.value:
document_list.update(dataset_id=target_dataset_id, type=Type.base,
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
paragraph_list.update(dataset_id=target_dataset_id)
ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
[problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
@ -713,6 +715,19 @@ class DocumentSerializers(ApiMixin, serializers.Serializer):
ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list)
return True
def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
    """Batch-update ``hit_handling_method`` for a list of documents.

    :param instance: {'id_list': [...document ids...],
                      'hit_handling_method': 'directly_return'|'optimization'}
    :param with_valid: run serializer/payload validation before updating
    :raises AppApiException: when the method is missing or not an allowed value
    """
    # Read once and reuse — the original fetched this twice.
    hit_handling_method = instance.get('hit_handling_method')
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        if hit_handling_method is None:
            raise AppApiException(500, '命中处理方式必填')
        # Membership test replaces the chained != comparisons.
        if hit_handling_method not in ('optimization', 'directly_return'):
            raise AppApiException(500, '命中处理方式必须为directly_return|optimization')
        self.is_valid(raise_exception=True)
    document_id_list = instance.get("id_list")
    QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method)
class FileBufferHandle:
buffer = None

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: maxkb
@Author
@file document_api.py
@date2024/4/28 13:56
@desc:
"""
from drf_yasg import openapi
from common.mixins.api_mixin import ApiMixin
class DocumentApi(ApiMixin):
    """Swagger/OpenAPI schema definitions for document endpoints."""

    class BatchEditHitHandlingApi(ApiMixin):
        """Request-body schema for batch-editing the hit handling method."""

        @staticmethod
        def get_request_body_api():
            # Body shape: {'id_list': [...document ids...],
            #              'hit_handling_method': 'directly_return'|'optimization'}
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
                                              title="主键id列表",
                                              description="主键id列表"),
                    'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
                                                          description="directly_return|optimization")
                }
            )

View File

@ -14,6 +14,7 @@ urlpatterns = [
path('dataset/<str:dataset_id>/document', views.Document.as_view(), name='document'),
path('dataset/<str:dataset_id>/document/web', views.WebDocument.as_view()),
path('dataset/<str:dataset_id>/document/_bach', views.Document.Batch.as_view()),
path('dataset/<str:dataset_id>/document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()),
path('dataset/<str:dataset_id>/document/<int:current_page>/<int:page_size>', views.Document.Page.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>', views.Document.Operate.as_view(),
name="document_operate"),

View File

@ -19,6 +19,7 @@ from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer
from dataset.swagger_api.document_api import DocumentApi
class WebDocument(APIView):
@ -71,6 +72,24 @@ class Document(APIView):
d.is_valid(raise_exception=True)
return result.success(d.list())
class BatchEditHitHandling(APIView):
    """Batch-update the hit handling method of documents in a dataset."""
    authentication_classes = [TokenAuth]

    # NOTE(review): @action declares POST but the handler is `put` — the
    # frontend calls this endpoint with PUT; confirm the intended verb.
    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="批量修改文档命中处理方式",
                         operation_id="批量修改文档命中处理方式",
                         request_body=
                         DocumentApi.BatchEditHitHandlingApi.get_request_body_api(),
                         manual_parameters=DocumentSerializers.Create.get_request_params_api(),
                         responses=result.get_default_response(),
                         tags=["知识库/文档"])
    @has_permissions(
        lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
                                dynamic_tag=k.get('dataset_id')))
    def put(self, request: Request, dataset_id: str):
        # Delegates validation and the bulk update to the serializer layer.
        return result.success(
            DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data))
class Batch(APIView):
authentication_classes = [TokenAuth]

View File

@ -18,27 +18,30 @@ def update_embedding_search_vector(embedding, paragraph_list):
def save_keywords(apps, schema_editor):
document = apps.get_model("dataset", "Document")
embedding = apps.get_model("embedding", "Embedding")
paragraph = apps.get_model('dataset', 'Paragraph')
db_alias = schema_editor.connection.alias
document_list = document.objects.using(db_alias).all()
for document in document_list:
document.status = Status.embedding
document.save()
paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
'paragraph')
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
in embedding_list]
child_array = sub_array(embedding_update_list, 50)
for c in child_array:
try:
embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
except Exception as e:
print(e)
document.status = Status.success
document.save()
try:
document = apps.get_model("dataset", "Document")
embedding = apps.get_model("embedding", "Embedding")
paragraph = apps.get_model('dataset', 'Paragraph')
db_alias = schema_editor.connection.alias
document_list = document.objects.using(db_alias).all()
for document in document_list:
document.status = Status.embedding
document.save()
paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all()
embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector',
'paragraph')
embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding
in embedding_list]
child_array = sub_array(embedding_update_list, 50)
for c in child_array:
try:
embedding.objects.using(db_alias).bulk_update(c, ['search_vector'])
except Exception as e:
print(e)
document.status = Status.success
document.save()
except Exception as e:
print(e)
class Migration(migrations.Migration):

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-28 18:06
from django.db import migrations, models
class Migration(migrations.Migration):
    # Widens Model.credential (encrypted model credentials) from 5120 to
    # 102400 chars to hold the longer rsa_long_encrypt output.

    dependencies = [
        ('setting', '0003_model_meta_model_status'),
    ]

    operations = [
        migrations.AlterField(
            model_name='model',
            name='credential',
            field=models.CharField(max_length=102400, verbose_name='模型认证信息'),
        ),
    ]

View File

@ -42,7 +42,7 @@ class Model(AppModelMixin):
provider = models.CharField(max_length=128, verbose_name='供应商')
credential = models.CharField(max_length=5120, verbose_name="模型认证信息")
credential = models.CharField(max_length=102400, verbose_name="模型认证信息")
meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict)

View File

@ -18,7 +18,7 @@ from rest_framework import serializers
from application.models import Application
from common.exception.app_exception import AppApiException
from common.util.field_message import ErrMessage
from common.util.rsa_util import encrypt, decrypt
from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt
from setting.models.model_management import Model, Status
from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer):
model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type,
model_name)
source_model_credential = json.loads(decrypt(model.credential))
source_model_credential = json.loads(rsa_long_decrypt(model.credential))
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
if credential is not None:
for k in source_encryption_model_credential.keys():
@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer):
model_name = self.data.get('model_name')
model_credential_str = json.dumps(credential)
model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name,
credential=encrypt(model_credential_str),
credential=rsa_long_encrypt(model_credential_str),
provider=provider, model_type=model_type, model_name=model_name)
model.save()
if status == Status.DOWNLOAD:
@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer):
@staticmethod
def model_to_dict(model: Model):
credential = json.loads(decrypt(model.credential))
credential = json.loads(rsa_long_decrypt(model.credential))
return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type,
'model_name': model.model_name,
'status': model.status,
@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer):
if update_key in instance and instance.get(update_key) is not None:
if update_key == 'credential':
model_credential_str = json.dumps(credential)
model.__setattr__(update_key, encrypt(model_credential_str))
model.__setattr__(update_key, rsa_long_encrypt(model_credential_str))
else:
model.__setattr__(update_key, instance.get(update_key))
model.save()

View File

@ -73,7 +73,7 @@ const postDataset: (data: datasetData, loading?: Ref<boolean>) => Promise<Result
data,
loading
) => {
return post(`${prefix}`, data, undefined, loading)
return post(`${prefix}`, data, undefined, loading, 1000 * 60 * 5)
}
/**

View File

@ -10,7 +10,7 @@ const prefix = '/dataset'
* @param file:file,limit:number,patterns:array,with_filter:boolean
*/
const postSplitDocument: (data: any) => Promise<Result<any>> = (data) => {
return post(`${prefix}/document/split`, data)
return post(`${prefix}/document/split`, data, undefined, undefined, 1000 * 60 * 60)
}
/**
@ -80,7 +80,7 @@ const postDocument: (
data: any,
loading?: Ref<boolean>
) => Promise<Result<any>> = (dataset_id, data, loading) => {
return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading)
return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading, 1000 * 60 * 5)
}
/**
@ -206,6 +206,20 @@ const putMigrateMulDocument: (
)
}
/**
 * Batch-edit the hit handling method for a set of documents.
 * @param dataset_id dataset id
 * @param data {id_list:[],hit_handling_method:'directly_return|optimization'}
 * @param loading optional loading-state ref toggled while the request runs
 * @returns promise resolving to a boolean success result
 */
const batchEditHitHandling: (
  dataset_id: string,
  data: any,
  loading?: Ref<boolean>
) => Promise<Result<boolean>> = (dataset_id, data, loading) => {
  return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading)
}
export default {
postSplitDocument,
getDocument,
@ -219,5 +233,6 @@ export default {
putDocumentRefresh,
delMulSyncDocument,
postWebDocument,
putMigrateMulDocument
putMigrateMulDocument,
batchEditHitHandling
}

View File

@ -56,7 +56,7 @@ export class ChatRecordManage {
this.chat.answer_text =
this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('')
} else if (this.is_close) {
this.chat.answer_text = this.chat.answer_text + this.chat.buffer.join('')
this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('')
this.chat.write_ed = true
this.write_ed = true
if (this.loading) {

View File

@ -5,6 +5,7 @@
v-model="dialogVisible"
destroy-on-close
append-to-body
align-center
>
<div class="paragraph-source-height">
<el-scrollbar>
@ -63,7 +64,7 @@
</el-dialog>
</template>
<script setup lang="ts">
import { ref, watch, nextTick } from 'vue'
import { ref, watch, onBeforeUnmount } from 'vue'
import { cloneDeep } from 'lodash'
import { arraySort } from '@/utils/utils'
const emit = defineEmits(['refresh'])
@ -86,12 +87,15 @@ const open = (data: any, id?: string) => {
detail.value.paragraph_list = arraySort(detail.value.paragraph_list, 'similarity', true)
dialogVisible.value = true
}
onBeforeUnmount(() => {
dialogVisible.value = false
})
defineExpose({ open })
</script>
<style lang="scss">
.paragraph-source {
padding: 0;
.el-dialog__header {
padding: 24px 24px 0 24px;
}
@ -102,4 +106,9 @@ defineExpose({ open })
height: calc(100vh - 260px);
}
}
@media only screen and (max-width: 768px) {
.paragraph-source {
width: 90% !important;
}
}
</style>

View File

@ -224,7 +224,7 @@ const chartOpenId = ref('')
const chatList = ref<any[]>([])
const isDisabledChart = computed(
() => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id)))
() => !(inputValue.value.trim() && (props.appId || props.data?.name))
)
const isMdArray = (val: string) => val.match(/^-\s.*/m)
const prologueList = computed(() => {
@ -274,12 +274,7 @@ function openParagraph(row: any, id?: string) {
}
function quickProblemHandle(val: string) {
if (!props.log && !loading.value && props.data?.name && props.data?.model_id) {
// inputValue.value = val
// nextTick(() => {
// quickInputRef.value?.focus()
// })
if (!loading.value && props.data?.name) {
handleDebounceClick(val)
}
}
@ -509,16 +504,14 @@ function regenerationChart(item: chatType) {
}
function getSourceDetail(row: any) {
logApi
.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading)
.then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => {
const exclude_keys = ['answer_text', 'id']
Object.keys(res.data).forEach((key) => {
if (!exclude_keys.includes(key)) {
row[key] = res.data[key]
}
})
})
return true
}

View File

@ -119,9 +119,15 @@ const promise: (
export const get: (
url: string,
params?: unknown,
loading?: NProgress | Ref<boolean>
) => Promise<Result<any>> = (url: string, params: unknown, loading?: NProgress | Ref<boolean>) => {
return promise(request({ url: url, method: 'get', params }), loading)
loading?: NProgress | Ref<boolean>,
timeout?: number
) => Promise<Result<any>> = (
url: string,
params: unknown,
loading?: NProgress | Ref<boolean>,
timeout?: number
) => {
return promise(request({ url: url, method: 'get', params, timeout: timeout }), loading)
}
/**
@ -136,9 +142,10 @@ export const post: (
url: string,
data?: unknown,
params?: unknown,
loading?: NProgress | Ref<boolean>
) => Promise<Result<any> | any> = (url, data, params, loading) => {
return promise(request({ url: url, method: 'post', data, params }), loading)
loading?: NProgress | Ref<boolean>,
timeout?: number
) => Promise<Result<any> | any> = (url, data, params, loading, timeout) => {
return promise(request({ url: url, method: 'post', data, params, timeout }), loading)
}
/**|
@ -153,9 +160,10 @@ export const put: (
url: string,
data?: unknown,
params?: unknown,
loading?: NProgress | Ref<boolean>
) => Promise<Result<any>> = (url, data, params, loading) => {
return promise(request({ url: url, method: 'put', data, params }), loading)
loading?: NProgress | Ref<boolean>,
timeout?: number
) => Promise<Result<any>> = (url, data, params, loading, timeout) => {
return promise(request({ url: url, method: 'put', data, params, timeout }), loading)
}
/**
@ -169,9 +177,10 @@ export const del: (
url: string,
params?: unknown,
data?: unknown,
loading?: NProgress | Ref<boolean>
) => Promise<Result<any>> = (url, params, data, loading) => {
return promise(request({ url: url, method: 'delete', params, data }), loading)
loading?: NProgress | Ref<boolean>,
timeout?: number
) => Promise<Result<any>> = (url, params, data, loading, timeout) => {
return promise(request({ url: url, method: 'delete', params, data, timeout }), loading)
}
/**

View File

@ -35,7 +35,7 @@
accept="image/*"
:on-change="onChange"
>
<el-button icon="Upload">上传</el-button>
<el-button icon="Upload" :disabled="radioType !== 'custom'">上传</el-button>
</el-upload>
</div>
<div class="el-upload__tip info mt-16">

View File

@ -48,7 +48,7 @@
<el-form-item label="AI 模型" prop="model_id">
<template #label>
<div class="flex-between">
<span>AI 模型 <span class="danger">*</span></span>
<span>AI 模型 </span>
</div>
</template>
<el-select
@ -56,6 +56,7 @@
placeholder="请选择 AI 模型"
class="w-full"
popper-class="select-model"
:clearable="true"
>
<el-option-group
v-for="(value, label) in modelOptions"
@ -338,7 +339,7 @@ const rules = reactive<FormRules<ApplicationFormType>>({
name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }],
model_id: [
{
required: true,
required: false,
message: '请选择模型',
trigger: 'change'
}

View File

@ -251,7 +251,7 @@ defineExpose({ open })
padding: 0 !important;
}
.dialog-max-height {
height: calc(100vh - 180px);
height: 550px;
}
.custom-slider {
.el-input-number.is-without-controls .el-input__wrapper {

View File

@ -0,0 +1,96 @@
<template>
<el-dialog
title="设置"
v-model="dialogVisible"
:close-on-click-modal="false"
:close-on-press-escape="false"
:destroy-on-close="true"
width="400"
>
<el-form
label-position="top"
ref="webFormRef"
:rules="rules"
:model="form"
require-asterisk-position="right"
>
<el-form-item>
<template #label>
<div class="flex align-center">
<span class="mr-4">命中处理方式</span>
<el-tooltip
effect="dark"
content="用户提问时,命中文档下的分段时按照设置的方式进行处理。"
placement="right"
>
<AppIcon iconName="app-warning" class="app-warning-icon"></AppIcon>
</el-tooltip>
</div>
</template>
<el-radio-group v-model="form.hit_handling_method">
<template v-for="(value, key) of hitHandlingMethod" :key="key">
<el-radio :value="key">{{ value }}</el-radio>
</template>
</el-radio-group>
</el-form-item>
</el-form>
<template #footer>
<span class="dialog-footer">
<el-button @click.prevent="dialogVisible = false"> 取消 </el-button>
<el-button type="primary" @click="submit(webFormRef)" :loading="loading"> 确定 </el-button>
</span>
</template>
</el-dialog>
</template>
<script setup lang="ts">
import { ref, reactive, watch } from 'vue'
import { useRoute } from 'vue-router'
import type { FormInstance, FormRules } from 'element-plus'
import documentApi from '@/api/document'
import { MsgSuccess } from '@/utils/message'
import { hitHandlingMethod } from '../utils'
const route = useRoute()
const {
params: { id }
} = route as any
const emit = defineEmits(['refresh'])
const webFormRef = ref()
const loading = ref<boolean>(false)
const documentList = ref<Array<string>>([])
const form = ref<any>({
hit_handling_method: 'optimization'
})
const rules = reactive({
source_url: [{ required: true, message: '请输入文档地址', trigger: 'blur' }]
})
const dialogVisible = ref<boolean>(false)
const open = (list: Array<string>) => {
documentList.value = list
dialogVisible.value = true
}
const submit = async (formEl: FormInstance | undefined) => {
if (!formEl) return
await formEl.validate((valid, fields) => {
if (valid) {
const obj = {
hit_handling_method: form.value.hit_handling_method,
id_list: documentList.value
}
documentApi.batchEditHitHandling(id, obj, loading).then((res: any) => {
MsgSuccess('设置成功')
emit('refresh')
dialogVisible.value = false
})
}
})
}
defineExpose({ open })
</script>
<style lang="scss" scoped></style>

View File

@ -21,10 +21,13 @@
>同步文档</el-button
>
<el-button @click="openDatasetDialog()" :disabled="multipleSelection.length === 0"
>批量迁移</el-button
>迁移</el-button
>
<el-button @click="openBatchEditDocument" :disabled="multipleSelection.length === 0"
>设置</el-button
>
<el-button @click="deleteMulDocument" :disabled="multipleSelection.length === 0"
>批量删除</el-button
>删除</el-button
>
</div>
@ -212,6 +215,10 @@
</div>
<ImportDocumentDialog ref="ImportDocumentDialogRef" :title="title" @refresh="refresh" />
<SyncWebDialog ref="SyncWebDialogRef" @refresh="refresh" />
<BatchEditDocumentDialog
ref="batchEditDocumentDialogRef"
@refresh="refresh"
></BatchEditDocumentDialog>
<!-- 选择知识库 -->
<SelectDatasetDialog ref="SelectDatasetDialogRef" @refresh="refresh" />
</div>
@ -225,6 +232,7 @@ import documentApi from '@/api/document'
import ImportDocumentDialog from './component/ImportDocumentDialog.vue'
import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue'
import SelectDatasetDialog from './component/SelectDatasetDialog.vue'
import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue'
import { numberFormat } from '@/utils/utils'
import { datetimeFormat } from '@/utils/time'
import { hitHandlingMethod } from './utils'
@ -257,7 +265,7 @@ onBeforeRouteLeave((to: any, from: any) => {
})
const beforePagination = computed(() => common.paginationConfig[storeKey])
const beforeSearch = computed(() => common.search[storeKey])
const batchEditDocumentDialogRef = ref<InstanceType<typeof BatchEditDocumentDialog>>()
const SyncWebDialogRef = ref()
const loading = ref(false)
let interval: any
@ -317,6 +325,13 @@ const handleSelectionChange = (val: any[]) => {
multipleSelection.value = val
}
function openBatchEditDocument() {
const arr: string[] = multipleSelection.value.map((v) => v.id)
if (batchEditDocumentDialogRef) {
batchEditDocumentDialogRef?.value?.open(arr)
}
}
/**
* 初始化轮询
*/
@ -356,9 +371,9 @@ function refreshDocument(row: any) {
.catch(() => {})
}
} else {
// documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => {
// getList()
// })
documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => {
getList()
})
}
}

View File

@ -1,7 +1,7 @@
<template>
<el-drawer v-model="visible" size="60%" @close="closeHandle" class="chat-record-drawer">
<template #header>
<h4>{{ currentAbstract }}</h4>
<h4 class="single-line">{{ currentAbstract }}</h4>
</template>
<div
v-loading="paginationConfig.current_page === 1 && loading"
@ -120,6 +120,11 @@ defineExpose({
})
</script>
<style lang="scss">
.single-line {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
}
.chat-record-drawer {
.el-drawer__body {
background: var(--app-layout-bg-color);