From 5705f3c4a819a4407cd5530a4af8fd312308617f Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 14:35:51 +0800 Subject: [PATCH 01/21] =?UTF-8?q?feat:=20=E6=96=87=E6=A1=A3=E6=89=B9?= =?UTF-8?q?=E9=87=8F=E8=AE=BE=E7=BD=AE=E5=91=BD=E4=B8=AD=E5=A4=84=E7=90=86?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=20(#293)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/document_serializers.py | 13 +++ apps/dataset/swagger_api/document_api.py | 27 ++++++ apps/dataset/urls.py | 1 + apps/dataset/views/document.py | 19 ++++ ui/src/api/document.ts | 17 +++- .../component/BatchEditDocumentDialog.vue | 96 +++++++++++++++++++ ui/src/views/document/index.vue | 21 +++- 7 files changed, 190 insertions(+), 4 deletions(-) create mode 100644 apps/dataset/swagger_api/document_api.py create mode 100644 ui/src/views/document/component/BatchEditDocumentDialog.vue diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index f2acfb9df..7d0149468 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -713,6 +713,19 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): ListenerManagement.delete_embedding_by_document_list_signal.send(document_id_list) return True + def batch_edit_hit_handling(self, instance: Dict, with_valid=True): + if with_valid: + BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True) + hit_handling_method = instance.get('hit_handling_method') + if hit_handling_method is None: + raise AppApiException(500, '命中处理方式必填') + if hit_handling_method != 'optimization' and hit_handling_method != 'directly_return': + raise AppApiException(500, '命中处理方式必须为directly_return|optimization') + self.is_valid(raise_exception=True) + document_id_list = instance.get("id_list") + hit_handling_method = instance.get('hit_handling_method') + QuerySet(Document).filter(id__in=document_id_list).update(hit_handling_method=hit_handling_method) + class FileBufferHandle: buffer = None diff --git a/apps/dataset/swagger_api/document_api.py b/apps/dataset/swagger_api/document_api.py new file mode 100644 index 000000000..1463a61c2 --- /dev/null +++ b/apps/dataset/swagger_api/document_api.py @@ -0,0 +1,27 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document_api.py + @date:2024/4/28 13:56 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class DocumentApi(ApiMixin): + class BatchEditHitHandlingApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="主键id列表", + description="主键id列表"), + 'hit_handling_method': openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式", + description="directly_return|optimization") + } + ) diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 5ed09a199..be68ccdc4 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -14,6 +14,7 @@ urlpatterns = [ path('dataset//document', views.Document.as_view(), name='document'), path('dataset//document/web', views.WebDocument.as_view()), path('dataset//document/_bach', views.Document.Batch.as_view()), + path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), path('dataset//document//', views.Document.Page.as_view()), path('dataset//document/', views.Document.Operate.as_view(), name="document_operate"), diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index fd6797b01..a727a31fa 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -19,6 +19,7 @@ from common.response import result from common.util.common import query_params_to_single_dict from dataset.serializers.common_serializers import BatchSerializer from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer +from dataset.swagger_api.document_api import DocumentApi class WebDocument(APIView): @@ -71,6 +72,24 @@ class Document(APIView): d.is_valid(raise_exception=True) return result.success(d.list()) + class BatchEditHitHandling(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量修改文档命中处理方式", + operation_id="批量修改文档命中处理方式", + request_body= + DocumentApi.BatchEditHitHandlingApi.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data)) + class Batch(APIView): authentication_classes = [TokenAuth] diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index fdd070573..8affd3742 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -206,6 +206,20 @@ const putMigrateMulDocument: ( ) } +/** + * 批量修改命中方式 + * @param dataset_id 知识库id + * @param data {id_list:[],hit_handling_method:'directly_return|optimization'} + * @param loading + * @returns + */ +const batchEditHitHandling: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading) +} export default { postSplitDocument, getDocument, @@ -219,5 +233,6 @@ export default { putDocumentRefresh, delMulSyncDocument, postWebDocument, - putMigrateMulDocument + putMigrateMulDocument, + batchEditHitHandling } diff --git a/ui/src/views/document/component/BatchEditDocumentDialog.vue b/ui/src/views/document/component/BatchEditDocumentDialog.vue new file mode 100644 index 000000000..55c0de0a5 --- /dev/null +++ b/ui/src/views/document/component/BatchEditDocumentDialog.vue @@ -0,0 +1,96 @@ + + + diff --git a/ui/src/views/document/index.vue b/ui/src/views/document/index.vue index 02a0434ec..65a5f2cc1 100644 --- a/ui/src/views/document/index.vue +++ b/ui/src/views/document/index.vue @@ -21,10 +21,13 @@ >同步文档 批量迁移迁移 + 设置 批量删除删除 @@ -212,6 +215,10 @@ + @@ -225,6 +232,7 @@ import documentApi from '@/api/document' import ImportDocumentDialog from './component/ImportDocumentDialog.vue' import SyncWebDialog from '@/views/dataset/component/SyncWebDialog.vue' import SelectDatasetDialog from './component/SelectDatasetDialog.vue' +import BatchEditDocumentDialog from './component/BatchEditDocumentDialog.vue' import { numberFormat } from '@/utils/utils' import { datetimeFormat } from '@/utils/time' import { hitHandlingMethod } from './utils' @@ -257,7 +265,7 @@ onBeforeRouteLeave((to: any, from: any) => { }) const beforePagination = computed(() => common.paginationConfig[storeKey]) const beforeSearch = computed(() => common.search[storeKey]) - +const batchEditDocumentDialogRef = ref>() const SyncWebDialogRef = ref() const loading = ref(false) let interval: any @@ -317,6 +325,13 @@ const handleSelectionChange = (val: any[]) => { multipleSelection.value = val } +function openBatchEditDocument() { + const arr: string[] = multipleSelection.value.map((v) => v.id) + if (batchEditDocumentDialogRef) { + batchEditDocumentDialogRef?.value?.open(arr) + } +} + /** * 初始化轮询 */ From 7b5ccd9089580c070c3e7bc0e608000a9ddbe7bf Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:09:12 +0800 Subject: [PATCH 02/21] =?UTF-8?q?perf:=20=E5=BA=94=E7=94=A8=E7=9A=84AI?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E4=BF=AE=E6=94=B9=E4=B8=BA=E4=B8=8D=E5=BF=85?= =?UTF-8?q?=E5=A1=AB=20(#297)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/i_chat_step.py | 2 +- .../step/chat_step/impl/base_chat_step.py | 88 +++++++++---------- .../i_reset_problem_step.py | 2 +- .../impl/base_reset_problem_step.py | 2 + .../serializers/application_serializers.py | 29 +++--- .../serializers/chat_message_serializers.py | 4 +- .../serializers/chat_serializers.py | 20 +++-- ui/src/components/ai-chat/index.vue | 18 ++-- ui/src/views/application/CreateAndSetting.vue | 5 +- 9 files changed, 91 insertions(+), 79 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 534b9d409..8fbac34c9 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -54,7 +54,7 @@ class IChatStep(IBaseChatPipelineStep): message_list = serializers.ListField(required=True, child=MessageField(required=True), error_messages=ErrMessage.list("对话列表")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.list("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.list("大语言模型")) # 段落列表 paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表")) # 对话id diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 0919abbe2..b53f12941 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -126,6 +126,26 @@ class BaseChatStep(IChatStep): result.append({'role': 'ai', 'content': answer_text}) return result + @staticmethod + def get_stream_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + directly_return_chunk_list = [AIMessageChunk(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return iter(directly_return_chunk_list), False + elif no_references_setting.get( + 'status') == 'designated_answer': + return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False + if chat_model is None: + return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False + else: + return chat_model.stream(message_list), True + def execute_stream(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -136,29 +156,8 @@ class BaseChatStep(IChatStep): padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False - # 调用模型 - if chat_model is None: - chat_result = iter( - [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) for paragraph in paragraph_list]) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = iter([AIMessageChunk(content=no_references_setting.get('value'))]) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - else: - chat_result = chat_model.stream(message_list) - is_ai_chat = True - + chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, + no_references_setting) chat_record_id = uuid.uuid1() r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, @@ -169,6 +168,27 @@ class BaseChatStep(IChatStep): r['Cache-Control'] = 'no-cache' return r + @staticmethod + def get_block_result(message_list: List[BaseMessage], + chat_model: BaseChatModel = None, + paragraph_list=None, + no_references_setting=None): + if paragraph_list is None: + paragraph_list = [] + + directly_return_chunk_list = [AIMessage(content=paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + return directly_return_chunk_list[0], False + elif no_references_setting.get( + 'status') == 'designated_answer': + return AIMessage(no_references_setting.get('value')), False + if chat_model is None: + return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False + else: + return chat_model.invoke(message_list), True + def execute_block(self, message_list: List[BaseMessage], chat_id, problem_text, @@ -178,28 +198,8 @@ class BaseChatStep(IChatStep): manage: PipelineManage = None, padding_problem_text: str = None, client_id=None, client_type=None, no_references_setting=None): - is_ai_chat = False # 调用模型 - if chat_model is None: - chat_result = AIMessage( - content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list])) - else: - if (paragraph_list is None or len(paragraph_list) == 0) and no_references_setting.get( - 'status') == 'designated_answer': - chat_result = AIMessage(content=no_references_setting.get('value')) - else: - if paragraph_list is not None and len(paragraph_list) > 0: - directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) - for paragraph in paragraph_list if - paragraph.hit_handling_method == 'directly_return'] - if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: - chat_result = iter(directly_return_chunk_list) - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True - else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True + chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list, no_references_setting) chat_record_id = uuid.uuid1() if is_ai_chat: request_token = chat_model.get_num_tokens_from_messages(message_list) diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index 930fd2482..ce30d96af 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -28,7 +28,7 @@ class IResetProblemStep(IBaseChatPipelineStep): history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), error_messages=ErrMessage.list("历史对答")) # 大语言模型 - chat_model = ModelField(error_messages=ErrMessage.base("大语言模型")) + chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型")) def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index 2386be4fe..aad66446c 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -22,6 +22,8 @@ prompt = ( class BaseResetProblemStep(IResetProblemStep): def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, **kwargs) -> str: + if chat_model is None: + return problem_text start_index = len(history_chat_record) - 3 history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 3307d873a..164e0083f 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -47,7 +47,8 @@ chat_cache = cache.caches['chat_cache'] class ModelDatasetAssociation(serializers.Serializer): user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型id")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, error_messages=ErrMessage.uuid( "知识库id")), @@ -57,8 +58,9 @@ class ModelDatasetAssociation(serializers.Serializer): super().is_valid(raise_exception=True) model_id = self.data.get('model_id') user_id = self.data.get('user_id') - if not QuerySet(Model).filter(id=model_id).exists(): - raise AppApiException(500, f'模型不存在【{model_id}】') + if model_id is not None and len(model_id) > 0: + if not QuerySet(Model).filter(id=model_id).exists(): + raise AppApiException(500, f'模型不存在【{model_id}】') dataset_id_list = list(set(self.data.get('dataset_id_list'))) exist_dataset_id_list = [str(dataset.id) for dataset in QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)] @@ -109,7 +111,8 @@ class ApplicationSerializer(serializers.Serializer): desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=256, min_length=1, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, error_messages=ErrMessage.char("开场白")) @@ -254,7 +257,8 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.char("应用名称")) desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("应用描述")) - model_id = serializers.CharField(required=False, error_messages=ErrMessage.char("模型")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("模型")) multiple_rounds_dialogue = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("多轮会话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=1024, @@ -494,13 +498,14 @@ class ApplicationSerializer(serializers.Serializer): application_id = self.data.get("application_id") application = QuerySet(Application).get(id=application_id) - - model = QuerySet(Model).filter( - id=instance.get('model_id') if 'model_id' in instance else application.model_id, - user_id=application.user_id).first() - if model is None: - raise AppApiException(500, "模型不存在") - + if instance.get('model_id') is None or len(instance.get('model_id')) == 0: + application.model_id = None + else: + model = QuerySet(Model).filter( + id=instance.get('model_id'), + user_id=application.user_id).first() + if model is None: + raise AppApiException(500, "模型不存在") update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status', 'dataset_setting', 'model_setting', 'problem_optimization', 'api_key_is_active', 'icon'] diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 47b905a77..7c89e9de4 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -167,9 +167,11 @@ class ChatMessageSerializer(serializers.Serializer): chat_cache.set(chat_id, chat_info, timeout=60 * 30) model = chat_info.application.model + if model is None: + return chat_info model = QuerySet(Model).filter(id=model.id).first() if model is None: - raise AppApiException(500, "模型不存在") + return chat_info if model.status == Status.ERROR: raise AppApiException(500, "当前模型不可用") if model.status == Status.DOWNLOAD: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 5f58fe876..7476219dd 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -213,7 +213,8 @@ class ChatSerializers(serializers.Serializer): id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) - model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.uuid("模型id")) multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("多轮会话")) @@ -246,14 +247,17 @@ class ChatSerializers(serializers.Serializer): def open(self): user_id = self.is_valid(raise_exception=True) chat_id = str(uuid.uuid1()) - model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() - if model is None: - raise AppApiException(500, "模型不存在") + model_id = self.data.get('model_id') + if model_id is not None and len(model_id) > 0: + model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() + chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, + json.loads( + decrypt(model.credential)), + streaming=True) + else: + model = None + chat_model = None dataset_id_list = self.data.get('dataset_id_list') - chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, - json.loads( - decrypt(model.credential)), - streaming=True) application = Application(id=None, dialogue_number=3, model=model, dataset_setting=self.data.get('dataset_setting'), model_setting=self.data.get('model_setting'), diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 8eb60639e..48ffc1151 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -224,7 +224,7 @@ const chartOpenId = ref('') const chatList = ref([]) const isDisabledChart = computed( - () => !(inputValue.value.trim() && (props.appId || (props.data?.name && props.data?.model_id))) + () => !(inputValue.value.trim() && (props.appId || props.data?.name)) ) const isMdArray = (val: string) => val.match(/^-\s.*/m) const prologueList = computed(() => { @@ -509,16 +509,14 @@ function regenerationChart(item: chatType) { } function getSourceDetail(row: any) { - logApi - .getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading) - .then((res) => { - const exclude_keys = ['answer_text', 'id'] - Object.keys(res.data).forEach((key) => { - if (!exclude_keys.includes(key)) { - row[key] = res.data[key] - } - }) + logApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => { + const exclude_keys = ['answer_text', 'id'] + Object.keys(res.data).forEach((key) => { + if (!exclude_keys.includes(key)) { + row[key] = res.data[key] + } }) + }) return true } diff --git a/ui/src/views/application/CreateAndSetting.vue b/ui/src/views/application/CreateAndSetting.vue index 5612f5561..ee666229c 100644 --- a/ui/src/views/application/CreateAndSetting.vue +++ b/ui/src/views/application/CreateAndSetting.vue @@ -48,7 +48,7 @@ >({ name: [{ required: true, message: '请输入应用名称', trigger: 'blur' }], model_id: [ { - required: true, + required: false, message: '请选择模型', trigger: 'change' } From ef8def269d065e603c60f14839657a2d502fb168 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:20:35 +0800 Subject: [PATCH 03/21] =?UTF-8?q?fix:=20=E5=BF=AB=E6=8D=B7=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E7=82=B9=E5=87=BB=E6=97=A0=E6=B3=95=E5=8F=91=E9=80=81?= =?UTF-8?q?=20(#298)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ui/src/components/ai-chat/index.vue | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 48ffc1151..7a83408cc 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -274,12 +274,7 @@ function openParagraph(row: any, id?: string) { } function quickProblemHandle(val: string) { - if (!props.log && !loading.value && props.data?.name && props.data?.model_id) { - // inputValue.value = val - // nextTick(() => { - // quickInputRef.value?.focus() - // }) - + if (!loading.value && props.data?.name) { handleDebounceClick(val) } } From 99483b1b4b7d5fd27f5054c9a358467332928e13 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 17:59:02 +0800 Subject: [PATCH 04/21] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E6=9C=AA?= =?UTF-8?q?=E9=85=8D=E7=BD=AEai=E6=A8=A1=E5=9E=8B=E5=9B=9E=E7=AD=94=20(#29?= =?UTF-8?q?9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chat_pipeline/step/chat_step/impl/base_chat_step.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index b53f12941..d03f17b4f 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -142,7 +142,7 @@ class BaseChatStep(IChatStep): 'status') == 'designated_answer': return iter([AIMessageChunk(content=no_references_setting.get('value'))]), False if chat_model is None: - return iter([AIMessageChunk('抱歉,没有在知识库中查询到相关信息。')]), False + return iter([AIMessageChunk('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。')]), False else: return chat_model.stream(message_list), True @@ -185,7 +185,7 @@ class BaseChatStep(IChatStep): 'status') == 'designated_answer': return AIMessage(no_references_setting.get('value')), False if chat_model is None: - return AIMessage('抱歉,没有在知识库中查询到相关信息。'), False + return AIMessage('抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'), False else: return chat_model.invoke(message_list), True From e4bcebd5796f364d99c25b7fa2cc8b310ee79a95 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 18:09:51 +0800 Subject: [PATCH 05/21] =?UTF-8?q?fix:=20API=5FKEY=20=E5=AD=97=E7=AC=A6?= =?UTF-8?q?=E9=95=BF=E5=BA=A6=E8=BF=87=E9=95=BF=E6=97=B6=EF=BC=8C=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=A8=A1=E5=9E=8B=E6=8A=A5=E9=94=99=20#278=20(#300)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../migrations/0004_alter_model_credential.py | 18 ++++++++++++++++++ apps/setting/models/model_management.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) create mode 100644 apps/setting/migrations/0004_alter_model_credential.py diff --git a/apps/setting/migrations/0004_alter_model_credential.py b/apps/setting/migrations/0004_alter_model_credential.py new file mode 100644 index 000000000..4b5e48858 --- /dev/null +++ b/apps/setting/migrations/0004_alter_model_credential.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-28 18:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0003_model_meta_model_status'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='credential', + field=models.CharField(max_length=102400, verbose_name='模型认证信息'), + ), + ] diff --git a/apps/setting/models/model_management.py b/apps/setting/models/model_management.py index d97100815..5bdd1b296 100644 --- a/apps/setting/models/model_management.py +++ b/apps/setting/models/model_management.py @@ -42,7 +42,7 @@ class Model(AppModelMixin): provider = models.CharField(max_length=128, verbose_name='供应商') - credential = models.CharField(max_length=5120, verbose_name="模型认证信息") + credential = models.CharField(max_length=102400, verbose_name="模型认证信息") meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) From e543acd2d1307b8003bf0e3ffb62abf78abb830e Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Sun, 28 Apr 2024 18:48:08 +0800 Subject: [PATCH 06/21] =?UTF-8?q?fix:=20=E3=80=90=E9=97=AE=E7=AD=94?= =?UTF-8?q?=E9=A1=B5=E9=9D=A2=E3=80=91=E6=8F=90=E9=97=AE=E8=BE=BE=E5=88=B0?= =?UTF-8?q?=E6=9C=80=E5=A4=A7=E9=99=90=E5=88=B6=E7=9A=84=E6=97=B6=E5=80=99?= =?UTF-8?q?=20=E8=BF=94=E5=9B=9E=E5=86=85=E5=AE=B9=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E3=80=82=20(#301)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ui/src/api/type/application.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index b39ed6707..6da9dd84a 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -56,7 +56,7 @@ export class ChatRecordManage { this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') } else if (this.is_close) { - this.chat.answer_text = this.chat.answer_text + this.chat.buffer.join('') + this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('') this.chat.write_ed = true this.write_ed = true if (this.loading) { From a5d58b7ec71bd9c4106cb8be3cefde2279b0ef2b Mon Sep 17 00:00:00 2001 From: wangdan Date: Sun, 28 Apr 2024 18:49:19 +0800 Subject: [PATCH 07/21] =?UTF-8?q?fix:=20=E6=A0=B7=E5=BC=8F=E4=BF=AE?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ui/src/components/ai-chat/ParagraphSourceDialog.vue | 13 +++++++++++-- .../application/components/ParamSettingDialog.vue | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/ui/src/components/ai-chat/ParagraphSourceDialog.vue b/ui/src/components/ai-chat/ParagraphSourceDialog.vue index 1b7cf2d70..2c8749dc3 100644 --- a/ui/src/components/ai-chat/ParagraphSourceDialog.vue +++ b/ui/src/components/ai-chat/ParagraphSourceDialog.vue @@ -5,6 +5,7 @@ v-model="dialogVisible" destroy-on-close append-to-body + align-center >
@@ -63,7 +64,7 @@ diff --git a/ui/src/views/application/components/ParamSettingDialog.vue b/ui/src/views/application/components/ParamSettingDialog.vue index d5db1e41e..92e22c274 100644 --- a/ui/src/views/application/components/ParamSettingDialog.vue +++ b/ui/src/views/application/components/ParamSettingDialog.vue @@ -251,7 +251,7 @@ defineExpose({ open }) padding: 0 !important; } .dialog-max-height { - height: calc(100vh - 180px); + height: 550px; } .custom-slider { .el-input-number.is-without-controls .el-input__wrapper { From 82d93631eabff21254246854aa3cd14c24c0da5a Mon Sep 17 00:00:00 2001 From: wangdan Date: Sun, 28 Apr 2024 19:00:30 +0800 Subject: [PATCH 08/21] =?UTF-8?q?perf:=20=E8=AE=BE=E7=BD=AE=E8=87=AA?= =?UTF-8?q?=E5=AE=9A=E4=B9=89Logo?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../views/application-overview/component/EditAvatarDialog.vue | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ui/src/views/application-overview/component/EditAvatarDialog.vue b/ui/src/views/application-overview/component/EditAvatarDialog.vue index e2f2e915d..9b1d3f840 100644 --- a/ui/src/views/application-overview/component/EditAvatarDialog.vue +++ b/ui/src/views/application-overview/component/EditAvatarDialog.vue @@ -35,7 +35,7 @@ accept="image/*" :on-change="onChange" > - 上传 + 上传
From c71a1ae79bed3ae359f19ec1691cba7965952c12 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 11:08:44 +0800 Subject: [PATCH 09/21] =?UTF-8?q?fix:=20=E5=A4=84=E7=90=86=E8=8E=B7?= =?UTF-8?q?=E5=8F=96tokens=E5=A4=B1=E8=B4=A5=E6=83=85=E5=86=B5=20(#308)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../step/chat_step/impl/base_chat_step.py | 8 ++++++-- .../reset_problem_step/impl/base_reset_problem_step.py | 10 ++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index d03f17b4f..74c60d1f0 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -59,8 +59,12 @@ def event_content(response, # 获取token if is_ai_chat: - request_token = chat_model.get_num_tokens_from_messages(message_list) - response_token = chat_model.get_num_tokens(all_text) + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(all_text) + except Exception as e: + request_token = 0 + response_token = 0 else: request_token = 0 response_token = 0 diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index aad66446c..2ab675570 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -37,8 +37,14 @@ class BaseResetProblemStep(IResetProblemStep): response.content.index('') + 6:response.content.index('')] if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: padding_problem = padding_problem_data - self.context['message_tokens'] = chat_model.get_num_tokens_from_messages(message_list) - self.context['answer_tokens'] = chat_model.get_num_tokens(padding_problem) + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(padding_problem) + except Exception as e: + request_token = 0 + response_token = 0 + self.context['message_tokens'] = request_token + self.context['answer_tokens'] = response_token return padding_problem def get_details(self, manage, **kwargs): From cd472b47f7f1326046986f9b3046231d6076ab98 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:28:47 +0800 Subject: [PATCH 10/21] =?UTF-8?q?fix:=20=E6=A8=A1=E5=9E=8B=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E9=95=BF=E5=AD=97=E7=AC=A6=E7=9A=84=E5=8A=A0=E5=AF=86?= =?UTF-8?q?=E8=A7=A3=E5=AF=86=E6=96=B9=E5=BC=8F=20(#310)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/chat_message_serializers.py | 4 +- .../serializers/chat_serializers.py | 8 +-- apps/common/util/rsa_util.py | 52 +++++++++++++++++++ .../serializers/provider_serializers.py | 10 ++-- 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 7c89e9de4..3f45a934d 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -27,7 +27,7 @@ from application.models.api_key_model import ApplicationPublicAccessClient, Appl from common.constants.authentication_type import AuthenticationType from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed from common.util.field_message import ErrMessage -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from common.util.split_model import flat_map from dataset.models import Paragraph, Document from setting.models import Model, Status @@ -225,7 +225,7 @@ class ChatMessageSerializer(serializers.Serializer): # 对话模型 chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt(model.credential)), streaming=True) # 数据集id列表 dataset_id_list = [str(row.dataset_id) for row in diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index 7476219dd..d8a3e648b 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -35,7 +35,7 @@ from common.util.common import post from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from common.util.lock import try_lock, un_lock -from common.util.rsa_util import decrypt +from common.util.rsa_util import rsa_long_decrypt from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.serializers.paragraph_serializers import ParagraphSerializers from setting.models import Model @@ -195,7 +195,8 @@ class ChatSerializers(serializers.Serializer): if model is not None: chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt( + model.credential)), streaming=True) chat_id = str(uuid.uuid1()) @@ -252,7 +253,8 @@ class ChatSerializers(serializers.Serializer): model = QuerySet(Model).filter(user_id=user_id, id=self.data.get('model_id')).first() chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name, json.loads( - decrypt(model.credential)), + rsa_long_decrypt( + model.credential)), streaming=True) else: model = None diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index ee93bf499..b808548b1 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -100,3 +100,55 @@ def decrypt(msg, pri_key: str | None = None): cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() + + +def rsa_long_decrypt(message, pri_key: str | None = None, length=256): + """ + 超长文本解密,默认不加密 + :param message: 需要解密的数据 + :param pri_key: 秘钥 + :param length : 1024bit的证书用128,2048bit证书用256位 + :return: 解密后的数据 + """ + + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 2ac28343a..351a98c8a 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -18,7 +18,7 @@ from rest_framework import serializers from application.models import Application from common.exception.app_exception import AppApiException from common.util.field_message import ErrMessage -from common.util.rsa_util import encrypt, decrypt +from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt from setting.models.model_management import Model, Status from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus from setting.models_provider.constants.model_provider_constants import ModelProvideConstants @@ -118,7 +118,7 @@ class ModelSerializer(serializers.Serializer): model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, model_name) - source_model_credential = json.loads(decrypt(model.credential)) + source_model_credential = json.loads(rsa_long_decrypt(model.credential)) source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) if credential is not None: for k in source_encryption_model_credential.keys(): @@ -170,7 +170,7 @@ class ModelSerializer(serializers.Serializer): model_name = self.data.get('model_name') model_credential_str = json.dumps(credential) model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, - credential=encrypt(model_credential_str), + credential=rsa_long_encrypt(model_credential_str), provider=provider, model_type=model_type, model_name=model_name) model.save() if status == Status.DOWNLOAD: @@ -180,7 +180,7 @@ class ModelSerializer(serializers.Serializer): @staticmethod def model_to_dict(model: Model): - credential = json.loads(decrypt(model.credential)) + credential = json.loads(rsa_long_decrypt(model.credential)) return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, 'model_name': model.model_name, 'status': model.status, @@ -252,7 +252,7 @@ class ModelSerializer(serializers.Serializer): if update_key in instance and instance.get(update_key) is not None: if update_key == 'credential': model_credential_str = json.dumps(credential) - model.__setattr__(update_key, encrypt(model_credential_str)) + model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) else: model.__setattr__(update_key, instance.get(update_key)) model.save() From c4c4934932d155f267cdcdb009094ad11cd42a84 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:35:26 +0800 Subject: [PATCH 11/21] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E5=AD=97=E7=AC=A6=E8=B6=85=E8=BF=87256=E6=97=A0?= =?UTF-8?q?=E6=B3=95=E5=AD=98=E5=82=A8=E5=AF=B9=E8=AF=9D=E6=97=A5=E5=BF=97?= =?UTF-8?q?=20#305=20(#311)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...t_abstract_alter_chatrecord_answer_text.py | 23 +++++++++++++++++++ apps/application/models/application.py | 4 ++-- .../serializers/chat_message_serializers.py | 2 +- 3 files changed, 26 insertions(+), 3 deletions(-) create mode 100644 apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py diff --git a/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py b/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py new file mode 100644 index 000000000..0643a39ce --- /dev/null +++ b/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-04-29 13:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_applicationaccesstoken_show_source'), + ] + + operations = [ + migrations.AlterField( + model_name='chat', + name='abstract', + field=models.CharField(max_length=1024, verbose_name='摘要'), + ), + migrations.AlterField( + model_name='chatrecord', + name='answer_text', + field=models.CharField(max_length=40960, verbose_name='答案'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index 774e1bdc2..6c77937bd 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -73,7 +73,7 @@ class ApplicationDatasetMapping(AppModelMixin): class Chat(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") application = models.ForeignKey(Application, on_delete=models.CASCADE) - abstract = models.CharField(max_length=256, verbose_name="摘要") + abstract = models.CharField(max_length=1024, verbose_name="摘要") client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) class Meta: @@ -96,7 +96,7 @@ class ChatRecord(AppModelMixin): vote_status = models.CharField(verbose_name='投票', max_length=10, choices=VoteChoices.choices, default=VoteChoices.UN_VOTE) problem_text = models.CharField(max_length=1024, verbose_name="问题") - answer_text = models.CharField(max_length=4096, verbose_name="答案") + answer_text = models.CharField(max_length=40960, verbose_name="答案") message_tokens = models.IntegerField(verbose_name="请求token数量", default=0) answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) const = models.IntegerField(verbose_name="总费用", default=0) diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 3f45a934d..f8c80a865 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -138,7 +138,7 @@ def get_post_handler(chat_info: ChatInfo): class ChatMessageSerializer(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id")) - message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) + message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"), max_length=1024) stream = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否流式回答")) re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否重新回答")) application_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("应用id")) From a788d8f3b8280bc6f4dc093d0587035451862e82 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 13:53:49 +0800 Subject: [PATCH 12/21] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E8=B6=85?= =?UTF-8?q?=E9=95=BF=E6=96=87=E6=9C=ACrsa=E5=8A=A0=E5=AF=86=E8=A7=A3?= =?UTF-8?q?=E5=AF=86=20(#312)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/util/rsa_util.py | 62 ++++++++++++++---------------------- 1 file changed, 24 insertions(+), 38 deletions(-) diff --git a/apps/common/util/rsa_util.py b/apps/common/util/rsa_util.py index b808548b1..003018672 100644 --- a/apps/common/util/rsa_util.py +++ b/apps/common/util/rsa_util.py @@ -62,18 +62,6 @@ def get_key_pair_by_sql(): return system_setting.meta -# def get_key_pair(): -# if not os.path.exists("/opt/maxkb/conf/receiver.pem"): -# kv = generate() -# private_file_out = open("/opt/maxkb/conf/private.pem", "wb") -# private_file_out.write(kv.get('value')) -# private_file_out.close() -# receiver_file_out = open("/opt/maxkb/conf/receiver.pem", "wb") -# receiver_file_out.write(kv.get('key')) -# receiver_file_out.close() -# return {'key': open("/opt/maxkb/conf/receiver.pem").read(), 'value': open("/opt/maxkb/conf/private.pem").read()} - - def encrypt(msg, public_key: str | None = None): """ 加密 @@ -111,28 +99,27 @@ def rsa_long_encrypt(message, public_key: str | None = None, length=200): :param length: 1024bit的证书用100, 2048bit的证书用 200 :return: 加密后的数据 """ - # 读取公钥 if public_key is None: public_key = get_key_pair().get('key') - cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, - passphrase=secret_code)) - # 处理:Plaintext is too long. 分段加密 - if len(message) <= length: - # 对编码的数据进行加密,并通过base64进行编码 - result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) - else: - rsa_text = [] - # 对编码后的数据进行切片,原因:加密长度不能过长 - for i in range(0, len(message), length): - cont = message[i:i + length] - # 对切片后的数据进行加密,并新增到text后面 - rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) - # 加密完进行拼接 - cipher_text = b''.join(rsa_text) - # base64进行编码 - result = base64.b64encode(cipher_text) - return result.decode() + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() def rsa_long_decrypt(message, pri_key: str | None = None, length=256): @@ -143,12 +130,11 @@ def rsa_long_decrypt(message, pri_key: str | None = None, length=256): :param length : 1024bit的证书用128,2048bit证书用256位 :return: 解密后的数据 """ - if pri_key is None: pri_key = get_key_pair().get('value') - cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) - base64_de = base64.b64decode(message) - res = [] - for i in range(0, len(base64_de), length): - res.append(cipher.decrypt(base64_de[i:i + length], 0)) - return b"".join(res).decode() + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() From 62c959f905f576cfef51300f73a8393887f28903 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:09:58 +0800 Subject: [PATCH 13/21] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E5=8E=86?= =?UTF-8?q?=E5=8F=B2=E6=95=B0=E6=8D=AE=E5=88=86=E8=AF=8D=E5=A4=84=E7=90=86?= =?UTF-8?q?=E5=85=9C=E5=BA=95=E6=93=8D=E4=BD=9C=20(#313)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../0002_embedding_search_vector.py | 45 ++++++++++--------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/apps/embedding/migrations/0002_embedding_search_vector.py b/apps/embedding/migrations/0002_embedding_search_vector.py index 3ed58d582..7d06d6046 100644 --- a/apps/embedding/migrations/0002_embedding_search_vector.py +++ b/apps/embedding/migrations/0002_embedding_search_vector.py @@ -18,27 +18,30 @@ def update_embedding_search_vector(embedding, paragraph_list): def save_keywords(apps, schema_editor): - document = apps.get_model("dataset", "Document") - embedding = apps.get_model("embedding", "Embedding") - paragraph = apps.get_model('dataset', 'Paragraph') - db_alias = schema_editor.connection.alias - document_list = document.objects.using(db_alias).all() - for document in document_list: - document.status = Status.embedding - document.save() - paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() - embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', - 'paragraph') - embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding - in embedding_list] - child_array = sub_array(embedding_update_list, 50) - for c in child_array: - try: - embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) - except Exception as e: - print(e) - document.status = Status.success - document.save() + try: + document = apps.get_model("dataset", "Document") + embedding = apps.get_model("embedding", "Embedding") + paragraph = apps.get_model('dataset', 'Paragraph') + db_alias = schema_editor.connection.alias + document_list = document.objects.using(db_alias).all() + for document in document_list: + document.status = Status.embedding + document.save() + paragraph_list = paragraph.objects.using(db_alias).filter(document=document).all() + embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', + 'paragraph') + embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding + in embedding_list] + child_array = sub_array(embedding_update_list, 50) + for c in child_array: + try: + embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) + except Exception as e: + print(e) + document.status = Status.success + document.save() + except Exception as e: + print(e) class Migration(migrations.Migration): From e5e2ece1bd8e764fb370338e3451aa0bdedbdc7d Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:23:32 +0800 Subject: [PATCH 14/21] =?UTF-8?q?perf:=20=E5=AF=B9=E8=AF=9D=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E6=91=98=E8=A6=81=E6=A0=B7=E5=BC=8F=E4=BC=98=E5=8C=96?= =?UTF-8?q?=20(#315)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ui/src/views/log/component/ChatRecordDrawer.vue | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ui/src/views/log/component/ChatRecordDrawer.vue b/ui/src/views/log/component/ChatRecordDrawer.vue index c80933411..7471dd182 100644 --- a/ui/src/views/log/component/ChatRecordDrawer.vue +++ b/ui/src/views/log/component/ChatRecordDrawer.vue @@ -1,7 +1,7 @@