From 4e615db713bc9f5fc11b5b9c1afd0fef423729f5 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Tue, 12 Nov 2024 17:49:22 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BF=AE=E6=94=B9=E6=A0=87=E6=B3=A8?= =?UTF-8?q?=E6=97=B6=E5=8F=AF=E4=BB=A5=E9=80=89=E6=8B=A9=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --story=1016833 --user=王孝刚 对话日志-修改标注时可以选择默认知识库 #229 https://www.tapd.cn/57709429/s/1608846 --- .../serializers/chat_serializers.py | 66 ++++++- apps/application/swagger_api/chat_api.py | 32 ++++ apps/application/urls.py | 4 + apps/application/views/chat_views.py | 26 ++- ui/src/api/log.ts | 20 +- .../views/log/component/EditContentDialog.vue | 23 ++- ui/src/views/log/index.vue | 173 +++++++++++++++++- 7 files changed, 335 insertions(+), 9 deletions(-) diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index d796b418b..40a16af2f 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -40,7 +40,7 @@ from common.util.lock import try_lock, un_lock from dataset.models import Document, Problem, Paragraph, ProblemParagraphMapping from dataset.serializers.common_serializers import get_embedding_model_id_by_dataset_id from dataset.serializers.paragraph_serializers import ParagraphSerializers -from embedding.task import embedding_by_paragraph +from embedding.task import embedding_by_paragraph, embedding_by_paragraph_list from setting.models import Model from setting.models_provider import get_model_credential from smartdoc.conf import PROJECT_DIR @@ -658,3 +658,67 @@ class ChatRecordSerializer(serializers.Serializer): data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id}) o.is_valid(raise_exception=True) return o.delete() + + class PostImprove(serializers.Serializer): + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id")) + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id")) + chat_ids = serializers.ListSerializer(child=serializers.UUIDField(), required=True, + error_messages=ErrMessage.list("对话id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not Document.objects.filter(id=self.data['document_id'], dataset_id=self.data['dataset_id']).exists(): + raise AppApiException(500, "文档id不正确") + + @staticmethod + def post_embedding_paragraph(paragraph_ids, dataset_id): + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + embedding_by_paragraph_list(paragraph_ids, model_id) + + @post(post_function=post_embedding_paragraph) + @transaction.atomic + def post_improve(self, instance: Dict): + ChatRecordSerializer.PostImprove(data=instance).is_valid(raise_exception=True) + + chat_ids = instance['chat_ids'] + document_id = instance['document_id'] + dataset_id = instance['dataset_id'] + + # 获取所有聊天记录 + chat_record_list = list(ChatRecord.objects.filter(chat_id__in=chat_ids)) + if len(chat_record_list) < len(chat_ids): + raise AppApiException(500, "存在不存在的对话记录") + + # 批量创建段落和问题映射 + paragraphs = [] + paragraph_ids = [] + problem_paragraph_mappings = [] + for chat_record in chat_record_list: + paragraph = Paragraph( + id=uuid.uuid1(), + document_id=document_id, + content=chat_record.answer_text, + dataset_id=dataset_id, + title=chat_record.problem_text + ) + problem, _ = Problem.objects.get_or_create(content=chat_record.problem_text, dataset_id=dataset_id) + problem_paragraph_mapping = ProblemParagraphMapping( + id=uuid.uuid1(), + dataset_id=dataset_id, + document_id=document_id, + problem_id=problem.id, + paragraph_id=paragraph.id + ) + paragraphs.append(paragraph) + paragraph_ids.append(paragraph.id) + problem_paragraph_mappings.append(problem_paragraph_mapping) + chat_record.improve_paragraph_id_list.append(paragraph.id) + + # 批量保存段落和问题映射 + Paragraph.objects.bulk_create(paragraphs) + ProblemParagraphMapping.objects.bulk_create(problem_paragraph_mappings) + + # 批量保存聊天记录 + ChatRecord.objects.bulk_update(chat_record_list, ['improve_paragraph_id_list']) + + return paragraph_ids, dataset_id diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py index cc2a50097..0681f11d8 100644 --- a/apps/application/swagger_api/chat_api.py +++ b/apps/application/swagger_api/chat_api.py @@ -267,6 +267,38 @@ class ImproveApi(ApiMixin): } ) + @staticmethod + def get_request_body_api_post(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['dataset_id', 'document_id', 'chat_ids'], + properties={ + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id"), + 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", + description="文档id"), + 'chat_ids': openapi.Schema(type=openapi.TYPE_ARRAY, title="会话id列表", + description="会话id列表", + items=openapi.Schema(type=openapi.TYPE_STRING)) + + } + ) + + @staticmethod + def get_request_params_api_post(): + return [openapi.Parameter(name='application_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='应用id'), + openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + + ] + class VoteApi(ApiMixin): @staticmethod diff --git a/apps/application/urls.py b/apps/application/urls.py index bca2725be..b4339ba84 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -61,6 +61,10 @@ urlpatterns = [ 'application//chat//chat_record//dataset//document_id//improve', views.ChatView.ChatRecord.Improve.as_view(), name=''), + path( + 'application//dataset//improve', + views.ChatView.ChatRecord.Improve.as_view(), + name=''), path('application//chat//chat_record//improve', views.ChatView.ChatRecord.ChatRecordImprove.as_view()), path('application/chat_message/', views.ChatView.Message.as_view(), name='application/message'), diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 927670bb2..4b60e4bcb 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -129,7 +129,8 @@ class ChatView(APIView): 'client_id': request.auth.client_id, 'form_data': (request.data.get( 'form_data') if 'form_data' in request.data else {}), - 'image_list': request.data.get('image_list') if 'image_list' in request.data else [], + 'image_list': request.data.get( + 'image_list') if 'image_list' in request.data else [], 'client_type': request.auth.client_type}).chat() @action(methods=['GET'], detail=False) @@ -364,6 +365,28 @@ class ChatView(APIView): data={'chat_id': chat_id, 'chat_record_id': chat_record_id, 'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data)) + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加至知识库", + operation_id="添加至知识库", + manual_parameters=ImproveApi.get_request_params_api_post(), + request_body=ImproveApi.get_request_body_api_post(), + tags=["应用/对话日志/添加至知识库"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + + ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.DATASET, + operate=Operate.MANAGE, + dynamic_tag=keywords.get( + 'dataset_id'))], + compare=CompareConstants.AND + ), compare=CompareConstants.AND) + def post(self, request: Request, application_id: str, dataset_id: str): + return result.success(ChatRecordSerializer.PostImprove().post_improve(request.data)) + class Operate(APIView): authentication_classes = [TokenAuth] @@ -417,4 +440,3 @@ class ChatView(APIView): file_url = FileSerializer(data={'file': file, 'meta': meta}).upload() file_ids.append({'name': file.name, 'url': file_url, 'file_id': file_url.split('/')[-1]}) return result.success(file_ids) - diff --git a/ui/src/api/log.ts b/ui/src/api/log.ts index 4222d4fc0..20973a371 100644 --- a/ui/src/api/log.ts +++ b/ui/src/api/log.ts @@ -1,5 +1,5 @@ import { Result } from '@/request/Result' -import { get, del, put, exportExcel, exportExcelPost } from '@/request/index' +import { get, del, put, exportExcel, exportExcelPost, post } from '@/request/index' import type { pageRequest } from '@/api/type/common' import { type Ref } from 'vue' @@ -114,7 +114,22 @@ const putChatRecordLog: ( loading ) } +/** + * 对话记录提交至知识库 + * @param data + * @param loading + * @param application_id + * @param dataset_id + */ +const postChatRecordLog: ( + application_id: string, + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (application_id, dataset_id, data, loading) => { + return post(`${prefix}/${application_id}/dataset/${dataset_id}/improve`, data, undefined, loading) +} /** * 获取标注段落列表信息 * @param 参数 @@ -215,5 +230,6 @@ export default { delMarkRecord, exportChatLog, getChatLogClient, - delChatClientLog + delChatClientLog, + postChatRecordLog } diff --git a/ui/src/views/log/component/EditContentDialog.vue b/ui/src/views/log/component/EditContentDialog.vue index 17731a094..d97316cd9 100644 --- a/ui/src/views/log/component/EditContentDialog.vue +++ b/ui/src/views/log/component/EditContentDialog.vue @@ -85,6 +85,7 @@ filterable placeholder="请选择文档" :loading="optionLoading" + @change="changeDocument" > { } function changeDataset(id: string) { + if (user.userInfo) { + localStorage.setItem(user.userInfo.id + 'chat_dataset_id', id) + } form.value.document_id = '' getDocument(id) } +function changeDocument(id: string) { + if (user.userInfo) { + localStorage.setItem(user.userInfo.id + 'chat_document_id', id) + } +} + function getDocument(id: string) { document.asyncGetAllDocument(id, loading).then((res: any) => { documentList.value = res.data @@ -229,11 +239,22 @@ function getDocument(id: string) { function getDataset() { application.asyncGetApplicationDataset(id, loading).then((res: any) => { datasetList.value = res.data + if (localStorage.getItem(user.userInfo?.id + 'chat_dataset_id')) { + form.value.dataset_id = localStorage.getItem(user.userInfo?.id + 'chat_dataset_id') as string + if (!datasetList.value.find((v) => v.id === form.value.dataset_id)) { + form.value.dataset_id = '' + } else { + getDocument(form.value.dataset_id) + } + } }) } const open = (data: any) => { getDataset() + if (localStorage.getItem(user.userInfo?.id + 'chat_document_id')) { + form.value.document_id = localStorage.getItem(user.userInfo?.id + 'chat_document_id') as string + } form.value.chat_id = data.chat_id form.value.record_id = data.id form.value.problem_text = data.problem_text ? data.problem_text.substring(0, 256) : '' diff --git a/ui/src/views/log/index.vue b/ui/src/views/log/index.vue index b7b21b97d..9df70acb6 100644 --- a/ui/src/views/log/index.vue +++ b/ui/src/views/log/index.vue @@ -30,8 +30,11 @@ clearable />
- 清除策略 + 清除策略 导出 + 添加至知识库
@@ -177,6 +180,86 @@ + + + + + + + + + + + + + + {{ item.name }} + + + + + + + + {{ item.name }} + + + + + +