diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index be6d9fe39..eb63a99be 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -245,8 +245,9 @@ class ApplicationCreateSerializer(serializers.Serializer): 'knowledge_id_list': self.data.get('knowledge_id_list')}).is_valid() @staticmethod - def to_application_model(user_id: str, application: Dict): + def to_application_model(user_id: str, workspace_id: str, application: Dict): return Application(id=uuid.uuid7(), name=application.get('name'), desc=application.get('desc'), + workspace_id=workspace_id, prologue=application.get('prologue'), dialogue_number=application.get('dialogue_number', 0), user_id=user_id, model_id=application.get('model_id'), @@ -321,7 +322,7 @@ class Query(serializers.Serializer): 'folder_query_set': folder_query_set, 'application_query_set': application_query_set, 'application_custom_sql': application_custom_sql_query_set - } if (workspace_manage and is_x_pack_ee) else {'folder_query_set': folder_query_set, + } if (workspace_manage and not is_x_pack_ee) else {'folder_query_set': folder_query_set, 'application_query_set': application_query_set, 'user_query_set': QuerySet( workspace_user_role_mapping_model).filter( @@ -442,8 +443,10 @@ class ApplicationSerializer(serializers.Serializer): def insert_simple(self, instance: Dict): self.is_valid(raise_exception=True) user_id = self.data.get('user_id') + workspace_id = self.data.get("workspace_id") ApplicationCreateSerializer.SimplateRequest(data=instance).is_valid(user_id=user_id, raise_exception=True) - application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, instance) + application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, workspace_id, + instance) dataset_id_list = instance.get('knowledge_id_list', []) application_knowledge_mapping_model_list = [ self.to_application_knowledge_mapping(application_model.id, dataset_id) for diff --git a/apps/chat/api/vote_api.py b/apps/chat/api/vote_api.py new file mode 100644 index 000000000..0066678f1 --- /dev/null +++ b/apps/chat/api/vote_api.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: vote_api.py + @date:2025/6/23 17:35 + @desc: +""" +from django.utils.translation import gettext_lazy as _ +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter + +from chat.serializers.chat_record import VoteRequest +from common.mixins.api_mixin import APIMixin +from common.result import DefaultResultSerializer + + +class VoteAPI(APIMixin): + @staticmethod + def get_request(): + return VoteRequest + + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="chat_id", + description=_("Chat ID"), + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="chat_record_id", + description=_("Chat Record ID"), + type=OpenApiTypes.STR, + location='path', + required=True, + ) + ] + + @staticmethod + def get_response(): + return DefaultResultSerializer diff --git a/apps/chat/serializers/chat_record.py b/apps/chat/serializers/chat_record.py new file mode 100644 index 000000000..8ff992d02 --- /dev/null +++ b/apps/chat/serializers/chat_record.py @@ -0,0 +1,65 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat_record.py + @date:2025/6/23 11:16 + @desc: +""" +from typing import Dict + +from django.db import transaction +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _, gettext +from rest_framework import serializers + +from application.models import VoteChoices, ChatRecord +from common.exception.app_exception import AppApiException +from common.utils.lock import try_lock, un_lock + + +class VoteRequest(serializers.Serializer): + vote_status = serializers.ChoiceField(choices=VoteChoices.choices, + label=_("Bidding Status")) + + +class VoteSerializer(serializers.Serializer): + chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) + + chat_record_id = serializers.UUIDField(required=True, + label=_("Conversation record id")) + + @transaction.atomic + def vote(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + VoteRequest(data=instance).is_valid(raise_exception=True) + if not try_lock(self.data.get('chat_record_id')): + raise AppApiException(500, + gettext( + "Voting on the current session minutes, please do not send repeated requests")) + try: + chat_record_details_model = QuerySet(ChatRecord).get(id=self.data.get('chat_record_id'), + chat_id=self.data.get('chat_id')) + if chat_record_details_model is None: + raise AppApiException(500, gettext("Non-existent conversation chat_record_id")) + vote_status = instance.get("vote_status") + if chat_record_details_model.vote_status == VoteChoices.UN_VOTE: + if vote_status == VoteChoices.STAR: + # 点赞 + chat_record_details_model.vote_status = VoteChoices.STAR + + if vote_status == VoteChoices.TRAMPLE: + # 点踩 + chat_record_details_model.vote_status = VoteChoices.TRAMPLE + chat_record_details_model.save() + else: + if vote_status == VoteChoices.UN_VOTE: + # 取消点赞 + chat_record_details_model.vote_status = VoteChoices.UN_VOTE + chat_record_details_model.save() + else: + raise AppApiException(500, gettext("Already voted, please cancel first and then vote again")) + finally: + un_lock(self.data.get('chat_record_id')) + return True diff --git a/apps/chat/urls.py b/apps/chat/urls.py index b2094f595..812a25b60 100644 --- a/apps/chat/urls.py +++ b/apps/chat/urls.py @@ -12,4 +12,5 @@ urlpatterns = [ path('chat_message/', views.ChatView.as_view()), path('open', views.OpenView.as_view()), path('captcha', views.CaptchaView.as_view(), name='captcha'), + path('vote/chat//chat_record/', views.VoteView.as_view(), name='vote'), ] diff --git a/apps/chat/views/__init__.py b/apps/chat/views/__init__.py index 9a1e1c14f..fa38335c9 100644 --- a/apps/chat/views/__init__.py +++ b/apps/chat/views/__init__.py @@ -8,3 +8,4 @@ """ from .chat_embed import * from .chat import * +from .chat_record import * diff --git a/apps/chat/views/chat.py b/apps/chat/views/chat.py index 0d4dbe648..970f27a41 100644 --- a/apps/chat/views/chat.py +++ b/apps/chat/views/chat.py @@ -128,9 +128,9 @@ class OpenView(APIView): class CaptchaView(APIView): @extend_schema(methods=['GET'], - summary=_("Get captcha"), - description=_("Get captcha"), - operation_id=_("Get captcha"), # type: ignore + summary=_("Get Chat captcha"), + description=_("Get Chat captcha"), + operation_id=_("Get Chat captcha"), # type: ignore tags=[_("User Management")], # type: ignore responses=CaptchaAPI.get_response()) def get(self, request: Request): diff --git a/apps/chat/views/chat_record.py b/apps/chat/views/chat_record.py new file mode 100644 index 000000000..9d1cabd02 --- /dev/null +++ b/apps/chat/views/chat_record.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat_record.py + @date:2025/6/23 10:42 + @desc: +""" +from drf_spectacular.utils import extend_schema +from rest_framework.request import Request +from rest_framework.views import APIView +from django.utils.translation import gettext_lazy as _ + +from chat.api.vote_api import VoteAPI +from chat.serializers.chat_record import VoteSerializer +from common import result +from common.auth import TokenAuth + + +class VoteView(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['PUT'], + description=_("Like, Dislike"), + summary=_("Like, Dislike"), + operation_id=_("Like, Dislike"), # type: ignore + parameters=VoteAPI.get_parameters(), + request=VoteAPI.get_request(), + responses=VoteAPI.get_response(), + tags=[_('Chat')] # type: ignore + ) + def put(self, request: Request, chat_id: str, chat_record_id: str): + return result.success(VoteSerializer( + data={'chat_id': chat_id, + 'chat_record_id': chat_record_id + }).vote(request.data)) diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index 50044463a..4c3c03e58 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -196,9 +196,9 @@ class KnowledgeSerializer(serializers.Serializer): if not root: raise serializers.ValidationError(_('Folder not found')) workspace_manage = is_workspace_manage(self.data.get('user_id'), self.data.get('workspace_id')) - + is_x_pack_ee = self.is_x_pack_ee() return native_search( - self.get_query_set(), + self.get_query_set(workspace_manage, is_x_pack_ee), select_string=get_file_content( os.path.join( PROJECT_DIR, diff --git a/ui/src/api/chat/chat.ts b/ui/src/api/chat/chat.ts index a7d4bf77f..4f1e630b5 100644 --- a/ui/src/api/chat/chat.ts +++ b/ui/src/api/chat/chat.ts @@ -159,6 +159,29 @@ const getAuthSetting: (auth_type: string, loading?: Ref) => Promise { return get(`/chat_user/${auth_type}/detail`, undefined, loading) } +/** + * 点赞点踩 + * @param chat_id 对话id + * @param chat_record_id 对话记录id + * @param vote_status 点赞状态 + * @param loading 加载器 + * @returns + */ +const vote: ( + chat_id: string, + chat_record_id: string, + vote_status: string, + loading?: Ref, +) => Promise> = (chat_id, chat_record_id, vote_status, loading) => { + return put( + `/vote/chat/${chat_id}/chat_record/${chat_record_id}`, + { + vote_status, + }, + undefined, + loading, + ) +} export default { open, chat, @@ -176,4 +199,5 @@ export default { ldapLogin, getAuthSetting, passwordAuthentication, + vote, } diff --git a/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue index 025679648..e5bb53d0f 100644 --- a/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue +++ b/ui/src/components/ai-chat/component/operation-button/ChatOperationButton.vue @@ -103,6 +103,7 @@ import { nextTick, onMounted, ref, onBeforeUnmount } from 'vue' import { useRoute } from 'vue-router' import { copyClick } from '@/utils/clipboard' import applicationApi from '@/api/application/application' +import chatAPI from '@/api/chat/chat' import { datetimeFormat } from '@/utils/time' import { MsgError } from '@/utils/message' import bus from '@/bus' @@ -118,7 +119,7 @@ const copy = (data: any) => { } const route = useRoute() const { - params: { id } + params: { id }, } = route as any const props = withDefaults( @@ -134,8 +135,8 @@ const props = withDefaults( }>(), { data: () => ({}), - type: 'ai-chat' - } + type: 'ai-chat', + }, ) const emit = defineEmits(['update:data', 'regeneration']) @@ -152,12 +153,10 @@ function regeneration() { } function voteHandle(val: string) { - applicationApi - .putChatVote(props.applicationId, props.chatId, props.data.record_id, val, loading) - .then(() => { - buttonData.value['vote_status'] = val - emit('update:data', buttonData.value) - }) + chatAPI.vote(props.chatId, props.data.record_id, val, loading).then(() => { + buttonData.value['vote_status'] = val + emit('update:data', buttonData.value) + }) } function markdownToPlainText(md: string) { @@ -203,9 +202,9 @@ function smartSplit( 0: 10, 1: 25, 3: 50, - 5: 100 + 5: 100, }, - is_end = false + is_end = false, ) { // 匹配中文逗号/句号,且后面至少还有20个字符(含任何字符,包括换行) const regex = /([。?\n])|(]*><\/audio>)/g @@ -261,7 +260,7 @@ enum AudioStatus { /** * 错误 */ - ERROR = 'ERROR' + ERROR = 'ERROR', } class AudioManage { textList: Array @@ -318,7 +317,7 @@ class AudioManage { .postTextToSpeech( (props.applicationId as string) || (id as string), { text: text }, - loading + loading, ) .then(async (res: any) => { if (res.type === 'application/json') { @@ -347,7 +346,7 @@ class AudioManage { this.audioList.push(audioElement) } else { const speechSynthesisUtterance: SpeechSynthesisUtterance = new SpeechSynthesisUtterance( - text + text, ) speechSynthesisUtterance.onend = () => { this.statusList[index] = AudioStatus.END @@ -381,7 +380,7 @@ class AudioManage { .postTextToSpeech( (props.applicationId as string) || (id as string), { text: text }, - loading + loading, ) .then(async (res: any) => { if (res.type === 'application/json') { @@ -432,7 +431,7 @@ class AudioManage { // 需要播放的内容 const index = this.statusList.findIndex((status) => - [AudioStatus.MOUNTED, AudioStatus.READY].includes(status) + [AudioStatus.MOUNTED, AudioStatus.READY].includes(status), ) if (index < 0 || this.statusList[index] === AudioStatus.MOUNTED) { return @@ -502,9 +501,9 @@ class AudioManage { { 0: 20, 1: 50, - 5: 100 + 5: 100, }, - is_end + is_end, ) return split