diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 534b9d409..8fbac34c9 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep): message_list = serializers.ListField(required=True, child=MessageField(required=True), error_messages=ErrMessage.list("对话列表")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型")) # 段落列表 paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) # 对话id diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 0919abbe2..276d01947 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -59,8 +59,12 @@ def event_content(response, # 获取token if is_ai_chat: - request_token = chat_model.get_num_tokens_from_messages(message_list) - response_token = chat_model.get_num_tokens(all_text) + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(all_text) + except Exception as e: + request_token = 0 + response_token = 0 else: request_token = 0 response_token = 0 @@ -126,6 +130,26 @@ class BaseChatStep(IChatStep): result.append({'role': 'ai', 'content': answer_text}) return result + @staticmethod + def get_stream_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return iter(directly_return_chunk_list), False + elif len(paragraph_list) == 0 and no_references_setting.get( + 'status') == 'designated_answer': + return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False + if chat_model is None: + return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False + else: + return chat_model.stream(message_list), True + def execute_stream(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -136,29 +160,8 @@ class BaseChatStep(IChatStep): padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False - # 调用模型 - if chat_model is None: - chat_result = iter( - [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list]) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))]) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - + chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, + no_references_setting) chat_record_id = uuid.uuid1() r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, @@ -169,6 +172,27 @@ class BaseChatStep(IChatStep): r['Cache-Control'] = 'no-cache' return r + @staticmethod + def get_block_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + + directly_return_chunk_list = [AIMessage(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return directly_return_chunk_list[0], False + elif len(paragraph_list) == 0 and no_references_setting.get( + 'status') == 'designated_answer': + return AIMessage(no_references_setting.get('value')), False + if chat_model is None: + return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False + else: + return chat_model.invoke(message_list), True + def execute_block(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -178,28 +202,8 @@ class BaseChatStep(IChatStep): manage: PipelineManage = None, padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False # 调用模型 - if chat_model is None: - chat_result = AIMessage( - content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list])) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = AIMessage(content=no_references_setting.get('value')) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True + chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting) chat_record_id = uuid.uuid1() if is_ai_chat: request_token = chat_model.get_num_tokens_from_messages(message_list) diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py index bff3aa2b2..6664a286c 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -28,7 +28,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): padding_problem_text: str = None, no_references_setting=None, **kwargs) -> List[BaseMessage]: - prompt = prompt if no_references_setting.get('status') == 'designated_answer' else no_references_setting.get( + prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get( 'value') exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text start_index = len(history_chat_record) - dialogue_number diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index 930fd2482..ce30d96af 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep): history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), error_messages=ErrMessage.list("历史对答")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型")) def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index 2386be4fe..c0595d590 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -22,6 +22,10 @@ prompt = ( class BaseResetProblemStep(IResetProblemStep): def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, **kwargs) -> str: + if chat_model is None: + self.context['message_tokens'] = 0 + self.context['answer_tokens'] = 0 + return problem_text start_index = len(history_chat_record) - 3 history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in @@ -35,8 +39,14 @@ class BaseResetProblemStep(IResetProblemStep): response.content.index('') + 6:response.content.index('')] if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: padding_problem = padding_problem_data - self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list) - self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem) + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(padding_problem) + except Exception as e: + request_token = 0 + response_token = 0 + self.context['message_tokens'] = request_token + self.context['answer_tokens'] = response_token return padding_problem def get_details(self, manage, **kwargs): diff --git a/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py b/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py new file mode 100644 index 000000000..0643a39ce --- /dev/null +++ b/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-04-29 13:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_applicationaccesstoken_show_source'), + ] + + operations = [ + migrations.AlterField( + model_name='chat', + name='abstract', + field=models.CharField(max_length=1024, verbose_name='摘要'), + ), + migrations.AlterField( + model_name='chatrecord', + name='answer_text', + field=models.CharField(max_length=40960, verbose_name='答案'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 774e1bdc2..6c77937bd 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -73,7 +73,7 @@ class ApplicationDatasetMapping(AppModelMixin): class Chat(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") application = models.ForeignKey(Application, on_delete=models.CASCADE) - abstract = models.CharField(max_length=256, verbose_name="摘要") + abstract = models.CharField(max_length=1024, verbose_name="摘要") client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) class Meta: @@ -96,7 +96,7 @@ class ChatRecord(AppModelMixin): vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, default=VoteChoices.UN_VOTE) problem_text = models.CharField(max_length=1024, verbose_name="问题") - answer_text = models.CharField(max_length=4096, verbose_name="答案") + answer_text = models.CharField(max_length=40960, verbose_name="答案") message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) const = models.IntegerField(verbose_name="总费用", default=0) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3307d873a..e0abb772d 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache'] class ModelDatasetAssociation(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型id")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, error_messages=ErrMessage.uuid( "知识库id")), @@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer): super().is_valid(raise_exception=True) model_id = self.data.get('model_id') user_id = self.data.get('user_id') - if not QuerySet(Model).filter(id=model_id).exists(): - raise AppApiException(500, f'模型不存在【{model_id}】') + if model_id is not None and len(model_id) > 0: + if not QuerySet(Model).filter(id=model_id).exists(): + raise AppApiException(500, f'模型不存在【{model_id}】') dataset_id_list = list(set(self.data.get('dataset_id_list'))) exist_dataset_id_list = [str(dataset.id) for dataset in QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] @@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer): desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, min_length=1, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, error_messages=ErrMessage.char("开场白")) @@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.char("应用名称")) desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("多轮会话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, @@ -494,22 +498,21 @@ class ApplicationSerializer(serializers.Serializer): application_id = self.data.get("application_id") application = QuerySet(Application).get(id=application_id) - - model = QuerySet(Model).filter( - id=instance.get('model_id') if 'model_id' in instance else application.model_id, - user_id=application.user_id).first() - if model is None: - raise AppApiException(500, "模型不存在") - + if instance.get('model_id') is None or len(instance.get('model_id')) == 0: + application.model_id = None + else: + model = QuerySet(Model).filter( + id=instance.get('model_id'), + user_id=application.user_id).first() + if model is None: + raise AppApiException(500, "模型不存在") update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'api_key_is_active', 'icon'] for update_key in update_keys: if update_key in instance and instance.get(update_key) is not None: if update_key == 'multiple_rounds_dialogue': - application.__setattr__('dialogue_number', - 0 if not instance.get(update_key) else ModelProvideConstants[ - model.provider].value.get_dialogue_number()) + application.__setattr__('dialogue_number', 0 if not instance.get(update_key) else 3) else: application.__setattr__(update_key, instance.get(update_key)) application.save() diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 47b905a77..f8c80a865 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status @@ -138,7 +138,7 @@ def get_post_handler(chat_info: ChatInfo): class ChatMessageSerializer(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) - message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) + message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"), max_length=1024) stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) @@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer): chat_cache.set(chat_id, chat_info, timeout=60 * 30) model = chat_info.application.model + if model is None: + return chat_info model = QuerySet(Model).filter(id=model.id).first() if model is None: - raise AppApiException(500, "模型不存在") + return chat_info if model.status == Status.ERROR: raise AppApiException(500, "当前模型不可用") if model.status == Status.DOWNLOAD: @@ -223,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer): # 对话模型 chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt(model.credential)), streaming=True) # 数据集id列表 dataset_id_list = [str(row.dataset_id) for row in diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 5f58fe876..d8a3e648b 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -35,7 +35,7 @@ from common.util.common import post from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model @@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer): if model is not None: chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt( + model.credential)), streaming=True) chat_id = str(uuid.uuid1()) @@ -213,7 +214,8 @@ class ChatSerializers(serializers.Serializer): id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) - model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.uuid("模型id")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("多轮会话")) @@ -246,14 +248,18 @@ class ChatSerializers(serializers.Serializer): def open(self): user_id = self.is_valid(raise_exception=True) chat_id = str(uuid.uuid1()) - model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() - if model is None: - raise AppApiException(500, "模型不存在") + model_id = self.data.get('model_id') + if model_id is not None and len(model_id) > 0: + model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + rsa_long_decrypt( + model.credential)), + streaming=True) + else: + model = None + chat_model = None dataset_id_list = self.data.get('dataset_id_list') - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - decrypt(model.credential)), - streaming=True) application = Application(id=None, dialogue_number=3, model=model, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index ee93bf499..003018672 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -62,18 +62,6 @@ def get_key_pair_by_sql(): return system_setting.meta -# def get_key_pair(): -# if not os.path.exists("/opt/maxkb/conf/receiver.pem"): -# kv = generate() -# private_file_out = open("/opt/maxkb/conf/private.pem", "wb") -# private_file_out.write(kv.get('value')) -# private_file_out.close() -# receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb") -# receiver_file_out.write(kv.get('key')) -# receiver_file_out.close() -# return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()} - - def encrypt(msg, public_key: str | None = None): """ 加密 @@ -100,3 +88,53 @@ def decrypt(msg, pri_key: str | None = None): cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() + + +def rsa_long_decrypt(message, pri_key: str | None = None, length=256): + """ + 超长文本解密,默认不加密 + :param message: 需要解密的数据 + :param pri_key: 秘钥 + :param length : 1024bit的证书用128,2048bit证书用256位 + :return: 解密后的数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index f2acfb9df..919172d13 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -164,6 +164,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): elif target_dataset.type == Type.base.value and dataset.type == Type.web.value: document_list.update(dataset_id=target_dataset_id, type=Type.base, meta={}) + else: + document_list.update(dataset_id=target_dataset_id) paragraph_list.update(dataset_id=target_dataset_id) ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs( [problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list], @@ -713,6 +715,19 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list) return True + def batch_edit_hit_handling(self, instance: Dict, with_valid=True): + if with_valid: + BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True) + hit_handling_method = instance.get('hit_handling_method') + if hit_handling_method is None: + raise AppApiException(500, '命中处理方式必填') + if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return': + raise AppApiException(500, '命中处理方式必须为directly_return|optimization') + self.is_valid(raise_exception=True) + document_id_list = instance.get("id_list") + hit_handling_method = instance.get('hit_handling_method') + QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method) + class FileBufferHandle: buffer = None diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py new file mode 100644 index 000000000..1463a61c2 --- /dev/null +++ b/apps/dataset/swagger_api/document_api.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document_api.py + @date:2024/4/28 13:56 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class DocumentApi(ApiMixin): + class BatchEditHitHandlingApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="主键id列表", + description="主键id列表"), + 'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式", + description="directly_return|optimization") + } + ) diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 5ed09a199..be68ccdc4 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -14,6 +14,7 @@ urlpatterns = [ path('dataset//document', views.Document.as_view(), name='document'), path('dataset//document/web', views.WebDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), + path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), path('dataset//document//', views.Document.Page.as_view()), path('dataset//document/', views.Document.Operate.as_view(), name="document_operate"), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index fd6797b01..a727a31fa 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -19,6 +19,7 @@ from common.response import result from common.util.common import query_params_to_single_dict from dataset.serializers.common_serializers import BatchSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer +from dataset.swagger_api.document_api import DocumentApi class WebDocument(APIView): @@ -71,6 +72,24 @@ class Document(APIView): d.is_valid(raise_exception=True) return result.success(d.list()) + class BatchEditHitHandling(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量修改文档命中处理方式", + operation_id="批量修改文档命中处理方式", + request_body= + DocumentApi.BatchEditHitHandlingApi.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data)) + class Batch(APIView): authentication_classes = [TokenAuth] diff --git a/apps/embedding/migrations/0002_embedding_search_vector.py b/apps/embedding/migrations/0002_embedding_search_vector.py index 3ed58d582..7d06d6046 100644 --- a/apps/embedding/migrations/0002_embedding_search_vector.py +++ b/apps/embedding/migrations/0002_embedding_search_vector.py @@ -18,27 +18,30 @@ def update_embedding_search_vector(embedding, paragraph_list): def save_keywords(apps, schema_editor): - document = apps.get_model("dataset", "Document") - embedding = apps.get_model("embedding", "Embedding") - paragraph = apps.get_model('dataset', 'Paragraph') - db_alias = schema_editor.connection.alias - document_list = document.objects.using(db_alias).all() - for document in document_list: - document.status = Status.embedding - document.save() - paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() - embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', - 'paragraph') - embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding - in embedding_list] - child_array = sub_array(embedding_update_list, 50) - for c in child_array: - try: - embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) - except Exception as e: - print(e) - document.status = Status.success - document.save() + try: + document = apps.get_model("dataset", "Document") + embedding = apps.get_model("embedding", "Embedding") + paragraph = apps.get_model('dataset', 'Paragraph') + db_alias = schema_editor.connection.alias + document_list = document.objects.using(db_alias).all() + for document in document_list: + document.status = Status.embedding + document.save() + paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() + embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', + 'paragraph') + embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding + in embedding_list] + child_array = sub_array(embedding_update_list, 50) + for c in child_array: + try: + embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) + except Exception as e: + print(e) + document.status = Status.success + document.save() + except Exception as e: + print(e) class Migration(migrations.Migration): diff --git a/apps/setting/migrations/0004_alter_model_credential.py b/apps/setting/migrations/0004_alter_model_credential.py new file mode 100644 index 000000000..4b5e48858 --- /dev/null +++ b/apps/setting/migrations/0004_alter_model_credential.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-28 18:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0003_model_meta_model_status'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='credential', + field=models.CharField(max_length=102400, verbose_name='模型认证信息'), + ), + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index d97100815..5bdd1b296 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -42,7 +42,7 @@ class Model(AppModelMixin): provider = models.CharField(max_length=128, verbose_name='供应商') - credential = models.CharField(max_length=5120, verbose_name="模型认证信息") + credential = models.CharField(max_length=102400, verbose_name="模型认证信息") meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 2ac28343a..351a98c8a 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -18,7 +18,7 @@ from rest_framework import serializers from application.models import Application from common.exception.app_exception import AppApiException from common.util.field_message import ErrMessage -from common.util.rsa_util import encrypt, decrypt +from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt from setting.models.model_management import Model, Status from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer): model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name) - source_model_credential = json.loads(decrypt(model.credential)) + source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: for k in source_encryption_model_credential.keys(): @@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer): model_name = self.data.get('model_name') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, - credential=encrypt(model_credential_str), + credential=rsa_long_encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name) model.save() if status == Status.DOWNLOAD: @@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer): @staticmethod def model_to_dict(model: Model): - credential = json.loads(decrypt(model.credential)) + credential = json.loads(rsa_long_decrypt(model.credential)) return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, @@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer): if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': model_credential_str = json.dumps(credential) - model.__setattr__(update_key, encrypt(model_credential_str)) + model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: model.__setattr__(update_key, instance.get(update_key)) model.save() diff --git a/ui/src/api/dataset.ts b/ui/src/api/dataset.ts index 850626cf4..702731a26 100644 --- a/ui/src/api/dataset.ts +++ b/ui/src/api/dataset.ts @@ -73,7 +73,7 @@ const postDataset: (data: datasetData, loading?: Ref) => Promise { - return post(`${prefix}`, data, undefined, loading) + return post(`${prefix}`, data, undefined, loading, 1000 * 60 * 5) } /** diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index fdd070573..2d2fc1f65 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -10,7 +10,7 @@ const prefix = '/dataset' * @param 参数 file:file,limit:number,patterns:array,with_filter:boolean */ const postSplitDocument: (data: any) => Promise> = (data) => { - return post(`${prefix}/document/split`, data) + return post(`${prefix}/document/split`, data, undefined, undefined, 1000 * 60 * 60) } /** @@ -80,7 +80,7 @@ const postDocument: ( data: any, loading?: Ref ) => Promise> = (dataset_id, data, loading) => { - return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading) + return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading, 1000 * 60 * 5) } /** @@ -206,6 +206,20 @@ const putMigrateMulDocument: ( ) } +/** + * 批量修改命中方式 + * @param dataset_id 知识库id + * @param data {id_list:[],hit_handling_method:'directly_return|optimization'} + * @param loading + * @returns + */ +const batchEditHitHandling: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading) +} export default { postSplitDocument, getDocument, @@ -219,5 +233,6 @@ export default { putDocumentRefresh, delMulSyncDocument, postWebDocument, - putMigrateMulDocument + putMigrateMulDocument, + batchEditHitHandling } diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index b39ed6707..6da9dd84a 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -56,7 +56,7 @@ export class ChatRecordManage { this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') } else if (this.is_close) { - this.chat.answer_text = this.chat.answer_text + this.chat.buffer.join('') + this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('') this.chat.write_ed = true this.write_ed = true if (this.loading) { diff --git a/ui/src/components/ai-chat/ParagraphSourceDialog.vue b/ui/src/components/ai-chat/ParagraphSourceDialog.vue index 1b7cf2d70..2c8749dc3 100644 --- a/ui/src/components/ai-chat/ParagraphSourceDialog.vue +++ b/ui/src/components/ai-chat/ParagraphSourceDialog.vue @@ -5,6 +5,7 @@ v-model="dialogVisible" destroy-on-close append-to-body + align-center >
@@ -63,7 +64,7 @@ diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 8eb60639e..7a83408cc 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -224,7 +224,7 @@ const chartOpenId = ref('') const chatList = ref([]) const isDisabledChart = computed( - () => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id))) + () => !(inputValue.value.trim() && (props.appId || props.data?.name)) ) const isMdArray = (val: string) => val.match(/^-\s.*/m) const prologueList = computed(() => { @@ -274,12 +274,7 @@ function openParagraph(row: any, id?: string) { } function quickProblemHandle(val: string) { - if (!props.log && !loading.value && props.data?.name && props.data?.model_id) { - // inputValue.value = val - // nextTick(() => { - // quickInputRef.value?.focus() - // }) - + if (!loading.value && props.data?.name) { handleDebounceClick(val) } } @@ -509,16 +504,14 @@ function regenerationChart(item: chatType) { } function getSourceDetail(row: any) { - logApi - .getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading) - .then((res) => { - const exclude_keys = ['answer_text', 'id'] - Object.keys(res.data).forEach((key) => { - if (!exclude_keys.includes(key)) { - row[key] = res.data[key] - } - }) + logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => { + const exclude_keys = ['answer_text', 'id'] + Object.keys(res.data).forEach((key) => { + if (!exclude_keys.includes(key)) { + row[key] = res.data[key] + } }) + }) return true } diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 6c83f3ed6..7d748ffd8 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -119,9 +119,15 @@ const promise: ( export const get: ( url: string, params?: unknown, - loading?: NProgress | Ref -) => Promise> = (url: string, params: unknown, loading?: NProgress | Ref) => { - return promise(request({ url: url, method: 'get', params }), loading) + loading?: NProgress | Ref, + timeout?: number +) => Promise> = ( + url: string, + params: unknown, + loading?: NProgress | Ref, + timeout?: number +) => { + return promise(request({ url: url, method: 'get', params, timeout: timeout }), loading) } /** @@ -136,9 +142,10 @@ export const post: ( url: string, data?: unknown, params?: unknown, - loading?: NProgress | Ref -) => Promise | any> = (url, data, params, loading) => { - return promise(request({ url: url, method: 'post', data, params }), loading) + loading?: NProgress | Ref, + timeout?: number +) => Promise | any> = (url, data, params, loading, timeout) => { + return promise(request({ url: url, method: 'post', data, params, timeout }), loading) } /**| @@ -153,9 +160,10 @@ export const put: ( url: string, data?: unknown, params?: unknown, - loading?: NProgress | Ref -) => Promise> = (url, data, params, loading) => { - return promise(request({ url: url, method: 'put', data, params }), loading) + loading?: NProgress | Ref, + timeout?: number +) => Promise> = (url, data, params, loading, timeout) => { + return promise(request({ url: url, method: 'put', data, params, timeout }), loading) } /** @@ -169,9 +177,10 @@ export const del: ( url: string, params?: unknown, data?: unknown, - loading?: NProgress | Ref -) => Promise> = (url, params, data, loading) => { - return promise(request({ url: url, method: 'delete', params, data }), loading) + loading?: NProgress | Ref, + timeout?: number +) => Promise> = (url, params, data, loading, timeout) => { + return promise(request({ url: url, method: 'delete', params, data, timeout }), loading) } /** diff --git a/ui/src/views/application-overview/component/EditAvatarDialog.vue b/ui/src/views/application-overview/component/EditAvatarDialog.vue index e2f2e915d..9b1d3f840 100644 --- a/ui/src/views/application-overview/component/EditAvatarDialog.vue +++ b/ui/src/views/application-overview/component/EditAvatarDialog.vue @@ -35,7 +35,7 @@ accept="image/*" :on-change="onChange" > - 上传 + 上传
diff --git a/ui/src/views/application/CreateAndSetting.vue b/ui/src/views/application/CreateAndSetting.vue index 5612f5561..ee666229c 100644 --- a/ui/src/views/application/CreateAndSetting.vue +++ b/ui/src/views/application/CreateAndSetting.vue @@ -48,7 +48,7 @@ >({ name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }], model_id: [ { - required: true, + required: false, message: '请选择模型', trigger: 'change' } diff --git a/ui/src/views/application/components/ParamSettingDialog.vue b/ui/src/views/application/components/ParamSettingDialog.vue index d5db1e41e..92e22c274 100644 --- a/ui/src/views/application/components/ParamSettingDialog.vue +++ b/ui/src/views/application/components/ParamSettingDialog.vue @@ -251,7 +251,7 @@ defineExpose({ open }) padding: 0 !important; } .dialog-max-height { - height: calc(100vh - 180px); + height: 550px; } .custom-slider { .el-input-number.is-without-controls .el-input__wrapper { diff --git a/ui/src/views/document/component/BatchEditDocumentDialog.vue b/ui/src/views/document/component/BatchEditDocumentDialog.vue new file mode 100644 index 000000000..55c0de0a5 --- /dev/null +++ b/ui/src/views/document/component/BatchEditDocumentDialog.vue @@ -0,0 +1,96 @@ + + + diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index 02a0434ec..8e5a5a4bc 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -21,10 +21,13 @@ >同步文档 批量迁移迁移 + 设置 批量删除删除
@@ -212,6 +215,10 @@ + @@ -225,6 +232,7 @@ import documentApi from '@/api/document' import ImportDocumentDialog from './component/ImportDocumentDialog.vue' import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue' import SelectDatasetDialog from './component/SelectDatasetDialog.vue' +import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue' import { numberFormat } from '@/utils/utils' import { datetimeFormat } from '@/utils/time' import { hitHandlingMethod } from './utils' @@ -257,7 +265,7 @@ onBeforeRouteLeave((to: any, from: any) => { }) const beforePagination = computed(() => common.paginationConfig[storeKey]) const beforeSearch = computed(() => common.search[storeKey]) - +const batchEditDocumentDialogRef = ref>() const SyncWebDialogRef = ref() const loading = ref(false) let interval: any @@ -317,6 +325,13 @@ const handleSelectionChange = (val: any[]) => { multipleSelection.value = val } +function openBatchEditDocument() { + const arr: string[] = multipleSelection.value.map((v) => v.id) + if (batchEditDocumentDialogRef) { + batchEditDocumentDialogRef?.value?.open(arr) + } +} + /** * 初始化轮询 */ @@ -356,9 +371,9 @@ function refreshDocument(row: any) { .catch(() => {}) } } else { - // documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { - // getList() - // }) + documentApi.putDocumentRefresh(row.dataset_id, row.id).then((res) => { + getList() + }) } } diff --git a/ui/src/views/log/component/ChatRecordDrawer.vue b/ui/src/views/log/component/ChatRecordDrawer.vue index c80933411..7471dd182 100644 --- a/ui/src/views/log/component/ChatRecordDrawer.vue +++ b/ui/src/views/log/component/ChatRecordDrawer.vue @@ -1,7 +1,7 @@