diff --git a/apps/application/serializers/application.py b/apps/application/serializers/application.py index 2e4bc033e..3f60a5b12 100644 --- a/apps/application/serializers/application.py +++ b/apps/application/serializers/application.py @@ -704,14 +704,14 @@ class ApplicationOperateSerializer(serializers.Serializer): self.is_valid() application_id = self.data.get("application_id") application = QuerySet(Application).get(id=application_id) - dataset_list = self.list_knowledge(with_valid=False) + knowledge_list = self.list_knowledge(with_valid=False) mapping_knowledge_id_list = [akm.knowledge_id for akm in QuerySet(ApplicationKnowledgeMapping).filter(application_id=application_id)] knowledge_id_list = [d.get('id') for d in list(filter(lambda row: mapping_knowledge_id_list.__contains__(row.get('id')), - dataset_list))] + knowledge_list))] return {**ApplicationSerializerModel(application).data, - 'dataset_id_list': knowledge_id_list} + 'knowledge_id_list': knowledge_id_list} def list_knowledge(self, with_valid=True): if with_valid: diff --git a/apps/application/serializers/application_chat_record.py b/apps/application/serializers/application_chat_record.py index 32c68a2f2..d411676c0 100644 --- a/apps/application/serializers/application_chat_record.py +++ b/apps/application/serializers/application_chat_record.py @@ -15,7 +15,8 @@ from django.db.models import QuerySet from rest_framework import serializers from rest_framework.utils.formatting import lazy_format -from application.models import ChatRecord +from application.models import ChatRecord, ApplicationAccessToken +from application.serializers.common import ChatInfo from common.db.search import page_search from common.exception.app_exception import AppApiException from common.utils.common import post @@ -37,6 +38,40 @@ class ChatRecordSerializerModel(serializers.ModelSerializer): 'create_time', 'update_time'] +class ChatRecordOperateSerializer(serializers.Serializer): + chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) + workspace_id = serializers.CharField(required=False, label=_("Workspace ID")) + application_id = serializers.UUIDField(required=True, label=_("Application ID")) + chat_record_id = serializers.UUIDField(required=True, label=_("Conversation record id")) + + def is_valid(self, *, debug=False, raise_exception=False): + super().is_valid(raise_exception=True) + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + if application_access_token is None: + raise AppApiException(500, gettext('Application authentication information does not exist')) + if not application_access_token.show_source and not debug: + raise AppApiException(500, gettext('Displaying knowledge sources is not enabled')) + + def get_chat_record(self): + chat_record_id = self.data.get('chat_record_id') + chat_id = self.data.get('chat_id') + chat_info: ChatInfo = ChatInfo.get_cache(chat_id) + if chat_info is not None: + chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if + str(chat_record.id) == str(chat_record_id)] + if chat_record_list is not None and len(chat_record_list): + return chat_record_list[-1] + return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() + + def one(self, debug): + self.is_valid(debug=debug, raise_exception=True) + chat_record = self.get_chat_record() + if chat_record is None: + raise AppApiException(500, gettext("Conversation does not exist")) + return ApplicationChatRecordQuerySerializers.reset_chat_record(chat_record) + + class ApplicationChatRecordQuerySerializers(serializers.Serializer): application_id = serializers.UUIDField(required=True, label=_("Application ID")) chat_id = serializers.UUIDField(required=True, label=_("Chat ID")) diff --git a/apps/application/urls.py b/apps/application/urls.py index 4cf1fc7f0..0c39febb8 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -34,6 +34,9 @@ urlpatterns = [ path( 'workspace//application//chat//chat_record', views.ApplicationChatRecord.as_view()), + path( + 'workspace//application//chat//chat_record/', + views.ApplicationChatRecordOperateAPI.as_view()), path( 'workspace//application//chat//chat_record//', views.ApplicationChatRecord.Page.as_view()), diff --git a/apps/application/views/application_chat_record.py b/apps/application/views/application_chat_record.py index dd23208b5..fef9047e9 100644 --- a/apps/application/views/application_chat_record.py +++ b/apps/application/views/application_chat_record.py @@ -14,7 +14,8 @@ from rest_framework.views import APIView from application.api.application_chat_record import ApplicationChatRecordQueryAPI, \ ApplicationChatRecordImproveParagraphAPI, ApplicationChatRecordAddKnowledgeAPI from application.serializers.application_chat_record import ApplicationChatRecordQuerySerializers, \ - ApplicationChatRecordImproveSerializer, ChatRecordImproveSerializer, ApplicationChatRecordAddKnowledgeSerializer + ApplicationChatRecordImproveSerializer, ChatRecordImproveSerializer, ApplicationChatRecordAddKnowledgeSerializer, \ + ChatRecordOperateSerializer from common import result from common.auth import TokenAuth from common.auth.authentication import has_permissions @@ -67,6 +68,30 @@ class ApplicationChatRecord(APIView): page_size=page_size)) +class ApplicationChatRecordOperateAPI(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['GET'], + description=_("Get conversation record details"), + summary=_("Get conversation record details"), + operation_id=_("Get conversation record details"), # type: ignore + request=ApplicationChatRecordQueryAPI.get_request(), + parameters=ApplicationChatRecordQueryAPI.get_parameters(), + responses=ApplicationChatRecordQueryAPI.get_response(), + tags=[_("Application/Conversation Log")] # type: ignore + ) + @has_permissions(PermissionConstants.APPLICATION_CHAT_LOG.get_workspace_application_permission(), + RoleConstants.WORKSPACE_MANAGE.get_workspace_role()) + def get(self, request: Request, workspace_id: str, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordOperateSerializer( + data={ + 'workspace_id': workspace_id, + 'application_id': application_id, + 'chat_id': chat_id, + 'chat_record_id': chat_record_id}).one(True)) + + class ApplicationChatRecordAddKnowledge(APIView): authentication_classes = [TokenAuth] diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index b47ba160e..9ae5ce10b 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -101,6 +101,8 @@ class DebugChatSerializers(serializers.Serializer): self.is_valid(raise_exception=True) chat_id = self.data.get('chat_id') chat_info: ChatInfo = ChatInfo.get_cache(chat_id) + application = QuerySet(Application).filter(id=chat_info.application_id).first() + chat_info.application = application return ChatSerializers(data={ 'chat_id': chat_id, "chat_user_id": chat_info.chat_user_id, "chat_user_type": chat_info.chat_user_type, diff --git a/ui/src/api/application/application.ts b/ui/src/api/application/application.ts index ab66596bd..633b79fa3 100644 --- a/ui/src/api/application/application.ts +++ b/ui/src/api/application/application.ts @@ -140,6 +140,27 @@ const getStatistics: ( ) => Promise> = (application_id, data, loading) => { return get(`${prefix}/${application_id}/application_stats`, data, loading) } +/** + * 打开调试对话id + * @param application_id 应用id + * @param loading 加载器 + * @returns + */ +const open: (application_id: string, loading?: Ref) => Promise> = ( + application_id, + loading, +) => { + return get(`${prefix}/${application_id}/open`, {}, loading) +} +/** + * 对话 + * @param 参数 + * chat_id: string + * data + */ +const chat: (chat_id: string, data: any) => Promise = (chat_id, data) => { + return postStream(`/api/chat_message/${chat_id}`, data) +} export default { getAllApplication, @@ -153,4 +174,6 @@ export default { exportApplication, importApplication, getStatistics, + open, + chat, } diff --git a/ui/src/api/application/chat-log.ts b/ui/src/api/application/chat-log.ts index 2327e664b..a626baff8 100644 --- a/ui/src/api/application/chat-log.ts +++ b/ui/src/api/application/chat-log.ts @@ -188,7 +188,18 @@ const postExportChatLog: ( loading, ) } - +const getChatRecordDetails: ( + application_id: string, + chat_id: string, + chat_record_id: string, + loading?: Ref, +) => Promise = (application_id, chat_id, chat_record_id, loading) => { + return get( + `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}`, + {}, + loading, + ) +} export default { postChatLogAddKnowledge, getChatLog, @@ -197,4 +208,5 @@ export default { putChatRecordLog, delMarkChatRecord, postExportChatLog, + getChatRecordDetails, } diff --git a/ui/src/api/model/model.ts b/ui/src/api/model/model.ts index b30e36758..b17c1a175 100644 --- a/ui/src/api/model/model.ts +++ b/ui/src/api/model/model.ts @@ -1,13 +1,13 @@ -import {Result} from '@/request/Result' -import {get, post, del, put} from '@/request/index' -import {type Ref} from 'vue' +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import { type Ref } from 'vue' import type { ListModelRequest, Model, CreateModelRequest, EditModelRequest, } from '@/api/type/model' -import type {FormField} from '@/components/dynamics-form/type' +import type { FormField } from '@/components/dynamics-form/type' const prefix = '/workspace/' + localStorage.getItem('workspace_id') @@ -21,7 +21,61 @@ const getModel: ( ) => Promise>> = (data, loading) => { return get(`${prefix}/model`, data, loading) } +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationRerankerModel: ( + application_id: string, + loading?: Ref, +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/model`, { model_type: 'RERANKER' }, loading) +} +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationSTTModel: ( + application_id: string, + loading?: Ref, +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/model`, { model_type: 'STT' }, loading) +} + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationTTSModel: ( + application_id: string, + loading?: Ref, +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/model`, { model_type: 'TTS' }, loading) +} + +const getApplicationImageModel: ( + application_id: string, + loading?: Ref, +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/model`, { model_type: 'IMAGE' }, loading) +} + +const getApplicationTTIModel: ( + application_id: string, + loading?: Ref, +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/model`, { model_type: 'TTI' }, loading) +} /** * 获取模型参数表单 * @param model_id 模型id @@ -82,10 +136,10 @@ const updateModelParamsForm: ( * @param loading 加载器 * @returns */ -const getModelById: ( - model_id: string, - loading?: Ref, -) => Promise> = (model_id, loading) => { +const getModelById: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading, +) => { return get(`${prefix}/model/${model_id}`, {}, loading) } /** @@ -94,10 +148,10 @@ const getModelById: ( * @param loading 加载器 * @returns */ -const getModelMetaById: ( - model_id: string, - loading?: Ref, -) => Promise> = (model_id, loading) => { +const getModelMetaById: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading, +) => { return get(`${prefix}/model/${model_id}/meta`, {}, loading) } /** @@ -106,16 +160,16 @@ const getModelMetaById: ( * @param loading 加载器 * @returns */ -const pauseDownload: ( - model_id: string, - loading?: Ref, -) => Promise> = (model_id, loading) => { +const pauseDownload: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading, +) => { return put(`${prefix}/model/${model_id}/pause_download`, undefined, {}, loading) } -const deleteModel: ( - model_id: string, - loading?: Ref, -) => Promise> = (model_id, loading) => { +const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading, +) => { return del(`${prefix}/model/${model_id}`, undefined, {}, loading) } export default { @@ -128,4 +182,9 @@ export default { pauseDownload, getModelParamsForm, updateModelParamsForm, + getApplicationRerankerModel, + getApplicationSTTModel, + getApplicationTTSModel, + getApplicationImageModel, + getApplicationTTIModel, } diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 27f948a75..6fd928846 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -5,7 +5,7 @@ :class="type" :style="{ height: firsUserInput ? '100%' : undefined, - paddingBottom: applicationDetails.disclaimer ? '20px' : 0 + paddingBottom: applicationDetails.disclaimer ? '20px' : 0, }" >
({}), available: true, - type: 'ai-chat' - } + type: 'ai-chat', + }, ) const emit = defineEmits(['refresh', 'scroll']) const { application, common } = useStore() @@ -163,13 +163,13 @@ const initialApiFormData = ref({}) const isUserInput = computed( () => props.applicationDetails.work_flow?.nodes?.filter((v: any) => v.id === 'base-node')[0] - .properties.user_input_field_list.length > 0 + .properties.user_input_field_list.length > 0, ) const isAPIInput = computed( () => props.type === 'debug-ai-chat' && props.applicationDetails.work_flow?.nodes?.filter((v: any) => v.id === 'base-node')[0] - .properties.api_input_field_list.length > 0 + .properties.api_input_field_list.length > 0, ) const showUserInputContent = computed(() => { return ( @@ -192,7 +192,7 @@ watch( } } }, - { deep: true, immediate: true } + { deep: true, immediate: true }, ) watch( @@ -200,7 +200,7 @@ watch( () => { chartOpenId.value = '' }, - { deep: true } + { deep: true }, ) watch( @@ -209,8 +209,8 @@ watch( chatList.value = value ? value : [] }, { - immediate: true - } + immediate: true, + }, ) const toggleUserInput = () => { @@ -292,9 +292,9 @@ const handleDebounceClick = debounce((val, other_params_data?: any, chat?: chatT */ const openChatId: () => Promise = () => { const obj = props.applicationDetails - if (props.appId) { + if (props.type === 'debug-ai-chat') { return applicationApi - .getChatOpen(props.appId) + .open(obj.id) .then((res) => { chartOpenId.value = res.data return res.data @@ -308,21 +308,7 @@ const openChatId: () => Promise = () => { return Promise.reject(res) }) } else { - if (isWorkFlow(obj.type)) { - const submitObj = { - work_flow: obj.work_flow, - user_id: obj.user - } - return applicationApi.postWorkflowChatOpen(submitObj).then((res) => { - chartOpenId.value = res.data - return res.data - }) - } else { - return applicationApi.postChatOpen(obj).then((res) => { - chartOpenId.value = res.data - return res.data - }) - } + return Promise.reject('暂不支持') } } /** @@ -447,8 +433,8 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para audio_list: other_params_data && other_params_data.audio_list ? other_params_data.audio_list : [], other_list: - other_params_data && other_params_data.other_list ? other_params_data.other_list : [] - } + other_params_data && other_params_data.other_list ? other_params_data.other_list : [], + }, }) chatList.value.push(chat) ChatManagement.addChatRecord(chat, 50, loading) @@ -470,16 +456,17 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para } else { const obj = { message: chat.problem_text, + stream: true, re_chat: re_chat || false, ...other_params_data, form_data: { ...form_data.value, - ...api_form_data.value - } + ...api_form_data.value, + }, } // 对话 applicationApi - .postChatMessage(chartOpenId.value, obj) + .chat(chartOpenId.value, obj) .then((response) => { if (response.status === 401) { application @@ -504,7 +491,7 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para const write = getWrite( chat, reader, - response.headers.get('Content-Type') !== 'application/json' + response.headers.get('Content-Type') !== 'application/json', ) return reader.read().then(write) } @@ -530,14 +517,16 @@ function chatMessage(chat?: any, problem?: string, re_chat?: boolean, other_para */ function getSourceDetail(row: any) { if (row.record_id) { - chatLogApi.getRecordDetail(id || props.appId, row.chat_id, row.record_id, loading).then((res) => { - const exclude_keys = ['answer_text', 'id', 'answer_text_list'] - Object.keys(res.data).forEach((key) => { - if (!exclude_keys.includes(key)) { - row[key] = res.data[key] - } + chatLogApi + .getChatRecordDetails(id || props.appId, row.chat_id, row.record_id, loading) + .then((res) => { + const exclude_keys = ['answer_text', 'id', 'answer_text_list'] + Object.keys(res.data).forEach((key) => { + if (!exclude_keys.includes(key)) { + row[key] = res.data[key] + } + }) }) - }) } return true } @@ -617,11 +606,11 @@ watch( () => { handleScroll() }, - { deep: true, immediate: true } + { deep: true, immediate: true }, ) defineExpose({ - setScrollBottom + setScrollBottom, })