feat: Chat vote (#3355)

This commit is contained in:
shaohuzhang1 2025-06-23 20:19:32 +08:00 committed by GitHub
parent 01ed7045e0
commit dca48d1388
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 199 additions and 26 deletions

View File

@ -245,8 +245,9 @@ class ApplicationCreateSerializer(serializers.Serializer):
'knowledge_id_list': self.data.get('knowledge_id_list')}).is_valid()
@staticmethod
def to_application_model(user_id: str, application: Dict):
def to_application_model(user_id: str, workspace_id: str, application: Dict):
return Application(id=uuid.uuid7(), name=application.get('name'), desc=application.get('desc'),
workspace_id=workspace_id,
prologue=application.get('prologue'),
dialogue_number=application.get('dialogue_number', 0),
user_id=user_id, model_id=application.get('model_id'),
@ -321,7 +322,7 @@ class Query(serializers.Serializer):
'folder_query_set': folder_query_set,
'application_query_set': application_query_set,
'application_custom_sql': application_custom_sql_query_set
} if (workspace_manage and is_x_pack_ee) else {'folder_query_set': folder_query_set,
} if (workspace_manage and not is_x_pack_ee) else {'folder_query_set': folder_query_set,
'application_query_set': application_query_set,
'user_query_set': QuerySet(
workspace_user_role_mapping_model).filter(
@ -442,8 +443,10 @@ class ApplicationSerializer(serializers.Serializer):
def insert_simple(self, instance: Dict):
self.is_valid(raise_exception=True)
user_id = self.data.get('user_id')
workspace_id = self.data.get("workspace_id")
ApplicationCreateSerializer.SimplateRequest(data=instance).is_valid(user_id=user_id, raise_exception=True)
application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, instance)
application_model = ApplicationCreateSerializer.SimplateRequest.to_application_model(user_id, workspace_id,
instance)
dataset_id_list = instance.get('knowledge_id_list', [])
application_knowledge_mapping_model_list = [
self.to_application_knowledge_mapping(application_model.id, dataset_id) for

43
apps/chat/api/vote_api.py Normal file
View File

@ -0,0 +1,43 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file vote_api.py
@date2025/6/23 17:35
@desc:
"""
from django.utils.translation import gettext_lazy as _
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter
from chat.serializers.chat_record import VoteRequest
from common.mixins.api_mixin import APIMixin
from common.result import DefaultResultSerializer
class VoteAPI(APIMixin):
@staticmethod
def get_request():
return VoteRequest
@staticmethod
def get_parameters():
return [OpenApiParameter(
name="chat_id",
description=_("Chat ID"),
type=OpenApiTypes.STR,
location='path',
required=True,
),
OpenApiParameter(
name="chat_record_id",
description=_("Chat Record ID"),
type=OpenApiTypes.STR,
location='path',
required=True,
)
]
@staticmethod
def get_response():
return DefaultResultSerializer

View File

@ -0,0 +1,65 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file chat_record.py
@date2025/6/23 11:16
@desc:
"""
from typing import Dict
from django.db import transaction
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _, gettext
from rest_framework import serializers
from application.models import VoteChoices, ChatRecord
from common.exception.app_exception import AppApiException
from common.utils.lock import try_lock, un_lock
class VoteRequest(serializers.Serializer):
vote_status = serializers.ChoiceField(choices=VoteChoices.choices,
label=_("Bidding Status"))
class VoteSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
chat_record_id = serializers.UUIDField(required=True,
label=_("Conversation record id"))
@transaction.atomic
def vote(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
VoteRequest(data=instance).is_valid(raise_exception=True)
if not try_lock(self.data.get('chat_record_id')):
raise AppApiException(500,
gettext(
"Voting on the current session minutes, please do not send repeated requests"))
try:
chat_record_details_model = QuerySet(ChatRecord).get(id=self.data.get('chat_record_id'),
chat_id=self.data.get('chat_id'))
if chat_record_details_model is None:
raise AppApiException(500, gettext("Non-existent conversation chat_record_id"))
vote_status = instance.get("vote_status")
if chat_record_details_model.vote_status == VoteChoices.UN_VOTE:
if vote_status == VoteChoices.STAR:
# 点赞
chat_record_details_model.vote_status = VoteChoices.STAR
if vote_status == VoteChoices.TRAMPLE:
# 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
chat_record_details_model.save()
else:
if vote_status == VoteChoices.UN_VOTE:
# 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save()
else:
raise AppApiException(500, gettext("Already voted, please cancel first and then vote again"))
finally:
un_lock(self.data.get('chat_record_id'))
return True

View File

@ -12,4 +12,5 @@ urlpatterns = [
path('chat_message/<str:chat_id>', views.ChatView.as_view()),
path('open', views.OpenView.as_view()),
path('captcha', views.CaptchaView.as_view(), name='captcha'),
path('vote/chat/<str:chat_id>/chat_record/<str:chat_record_id>', views.VoteView.as_view(), name='vote'),
]

View File

@ -8,3 +8,4 @@
"""
from .chat_embed import *
from .chat import *
from .chat_record import *

View File

@ -128,9 +128,9 @@ class OpenView(APIView):
class CaptchaView(APIView):
@extend_schema(methods=['GET'],
summary=_("Get captcha"),
description=_("Get captcha"),
operation_id=_("Get captcha"), # type: ignore
summary=_("Get Chat captcha"),
description=_("Get Chat captcha"),
operation_id=_("Get Chat captcha"), # type: ignore
tags=[_("User Management")], # type: ignore
responses=CaptchaAPI.get_response())
def get(self, request: Request):

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file chat_record.py
@date2025/6/23 10:42
@desc:
"""
from drf_spectacular.utils import extend_schema
from rest_framework.request import Request
from rest_framework.views import APIView
from django.utils.translation import gettext_lazy as _
from chat.api.vote_api import VoteAPI
from chat.serializers.chat_record import VoteSerializer
from common import result
from common.auth import TokenAuth
class VoteView(APIView):
authentication_classes = [TokenAuth]
@extend_schema(
methods=['PUT'],
description=_("Like, Dislike"),
summary=_("Like, Dislike"),
operation_id=_("Like, Dislike"), # type: ignore
parameters=VoteAPI.get_parameters(),
request=VoteAPI.get_request(),
responses=VoteAPI.get_response(),
tags=[_('Chat')] # type: ignore
)
def put(self, request: Request, chat_id: str, chat_record_id: str):
return result.success(VoteSerializer(
data={'chat_id': chat_id,
'chat_record_id': chat_record_id
}).vote(request.data))

View File

@ -196,9 +196,9 @@ class KnowledgeSerializer(serializers.Serializer):
if not root:
raise serializers.ValidationError(_('Folder not found'))
workspace_manage = is_workspace_manage(self.data.get('user_id'), self.data.get('workspace_id'))
is_x_pack_ee = self.is_x_pack_ee()
return native_search(
self.get_query_set(),
self.get_query_set(workspace_manage, is_x_pack_ee),
select_string=get_file_content(
os.path.join(
PROJECT_DIR,

View File

@ -159,6 +159,29 @@ const getAuthSetting: (auth_type: string, loading?: Ref<boolean>) => Promise<Res
) => {
return get(`/chat_user/${auth_type}/detail`, undefined, loading)
}
/**
*
* @param chat_id id
* @param chat_record_id id
* @param vote_status
* @param loading
* @returns
*/
const vote: (
chat_id: string,
chat_record_id: string,
vote_status: string,
loading?: Ref<boolean>,
) => Promise<Result<boolean>> = (chat_id, chat_record_id, vote_status, loading) => {
return put(
`/vote/chat/${chat_id}/chat_record/${chat_record_id}`,
{
vote_status,
},
undefined,
loading,
)
}
export default {
open,
chat,
@ -176,4 +199,5 @@ export default {
ldapLogin,
getAuthSetting,
passwordAuthentication,
vote,
}

View File

@ -103,6 +103,7 @@ import { nextTick, onMounted, ref, onBeforeUnmount } from 'vue'
import { useRoute } from 'vue-router'
import { copyClick } from '@/utils/clipboard'
import applicationApi from '@/api/application/application'
import chatAPI from '@/api/chat/chat'
import { datetimeFormat } from '@/utils/time'
import { MsgError } from '@/utils/message'
import bus from '@/bus'
@ -118,7 +119,7 @@ const copy = (data: any) => {
}
const route = useRoute()
const {
params: { id }
params: { id },
} = route as any
const props = withDefaults(
@ -134,8 +135,8 @@ const props = withDefaults(
}>(),
{
data: () => ({}),
type: 'ai-chat'
}
type: 'ai-chat',
},
)
const emit = defineEmits(['update:data', 'regeneration'])
@ -152,12 +153,10 @@ function regeneration() {
}
function voteHandle(val: string) {
applicationApi
.putChatVote(props.applicationId, props.chatId, props.data.record_id, val, loading)
.then(() => {
buttonData.value['vote_status'] = val
emit('update:data', buttonData.value)
})
chatAPI.vote(props.chatId, props.data.record_id, val, loading).then(() => {
buttonData.value['vote_status'] = val
emit('update:data', buttonData.value)
})
}
function markdownToPlainText(md: string) {
@ -203,9 +202,9 @@ function smartSplit(
0: 10,
1: 25,
3: 50,
5: 100
5: 100,
},
is_end = false
is_end = false,
) {
// /20
const regex = /([。?\n])|(<audio[^>]*><\/audio>)/g
@ -261,7 +260,7 @@ enum AudioStatus {
/**
* 错误
*/
ERROR = 'ERROR'
ERROR = 'ERROR',
}
class AudioManage {
textList: Array<string>
@ -318,7 +317,7 @@ class AudioManage {
.postTextToSpeech(
(props.applicationId as string) || (id as string),
{ text: text },
loading
loading,
)
.then(async (res: any) => {
if (res.type === 'application/json') {
@ -347,7 +346,7 @@ class AudioManage {
this.audioList.push(audioElement)
} else {
const speechSynthesisUtterance: SpeechSynthesisUtterance = new SpeechSynthesisUtterance(
text
text,
)
speechSynthesisUtterance.onend = () => {
this.statusList[index] = AudioStatus.END
@ -381,7 +380,7 @@ class AudioManage {
.postTextToSpeech(
(props.applicationId as string) || (id as string),
{ text: text },
loading
loading,
)
.then(async (res: any) => {
if (res.type === 'application/json') {
@ -432,7 +431,7 @@ class AudioManage {
//
const index = this.statusList.findIndex((status) =>
[AudioStatus.MOUNTED, AudioStatus.READY].includes(status)
[AudioStatus.MOUNTED, AudioStatus.READY].includes(status),
)
if (index < 0 || this.statusList[index] === AudioStatus.MOUNTED) {
return
@ -502,9 +501,9 @@ class AudioManage {
{
0: 20,
1: 50,
5: 100
5: 100,
},
is_end
is_end,
)
return split