From 23c7269231fb63b70c87a1e13c6838891d70f925 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Tue, 22 Oct 2024 18:29:13 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=AF=B9=E8=AF=9D=E7=A7=98?= =?UTF-8?q?=E9=92=A5=E6=A0=A1=E9=AA=8C=20(#1430)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 38 +++- apps/application/views/application_views.py | 47 ++--- .../auth/handle/impl/public_access_token.py | 18 +- apps/common/util/common.py | 13 ++ ui/src/api/application.ts | 16 +- ui/src/components/ai-chat/index.vue | 163 +++++++++--------- ui/src/request/index.ts | 3 + ui/src/stores/modules/application.ts | 8 +- ui/src/views/chat/auth/component/password.vue | 83 +++++++++ ui/src/views/chat/auth/index.vue | 56 ++++++ ui/src/views/chat/base/index.vue | 46 ++--- ui/src/views/chat/embed/index.vue | 133 +++----------- ui/src/views/chat/index.vue | 73 ++++++-- ui/src/views/chat/pc/index.vue | 136 +++------------ 14 files changed, 455 insertions(+), 378 deletions(-) create mode 100644 ui/src/views/chat/auth/component/password.vue create mode 100644 ui/src/views/chat/auth/index.vue diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index 005740b00..b59b0f10e 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -36,7 +36,7 @@ from common.db.sql_execute import select_list from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField from common.models.db_model_manage import DBModelManage -from common.util.common import valid_license +from common.util.common import valid_license, password_encrypt from common.util.field_message import ErrMessage from common.util.file_util import get_file_content from dataset.models import DataSet, Document, Image @@ -264,7 +264,9 @@ class ApplicationSerializer(serializers.Serializer): if work_flow is not None: for node in work_flow.get('nodes', []): if node['id'] == 'base-node': - input_field_list = node.get('properties', {}).get('api_input_field_list', node.get('properties', {}).get('input_field_list', [])) + input_field_list = node.get('properties', {}).get('api_input_field_list', + node.get('properties', {}).get( + 'input_field_list', [])) if input_field_list is not None: for field in input_field_list: if field['assignment_method'] == 'api_input' and field['variable'] in params: @@ -352,6 +354,8 @@ class ApplicationSerializer(serializers.Serializer): class Authentication(serializers.Serializer): access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token")) + authentication_value = serializers.JSONField(required=False, allow_null=True, + error_messages=ErrMessage.char("认证信息")) def auth(self, request, with_valid=True): token = request.META.get('HTTP_AUTHORIZATION') @@ -366,21 +370,47 @@ class ApplicationSerializer(serializers.Serializer): self.is_valid(raise_exception=True) access_token = self.data.get("access_token") application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + authentication_value = self.data.get('authentication_value', None) + authentication = {} if application_access_token is not None and application_access_token.is_active: if token_details is not None and 'client_id' in token_details and token_details.get( 'client_id') is not None: client_id = token_details.get('client_id') + authentication = {'type': token_details.get('type'), + 'value': token_details.get('value')} else: client_id = str(uuid.uuid1()) + if authentication_value is not None: + # 认证用户token + self.auth_authentication_value(authentication_value, str(application_access_token.application_id)) + authentication = {'type': authentication_value.get('type'), + 'value': password_encrypt(authentication_value.get('value'))} token = signing.dumps({'application_id': str(application_access_token.application_id), 'user_id': str(application_access_token.application.user.id), 'access_token': application_access_token.access_token, 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value, - 'client_id': client_id}) + 'client_id': client_id + , **authentication}) return token else: raise NotFound404(404, "无效的access_token") + def auth_authentication_value(self, authentication_value, application_id): + application_setting_model = DBModelManage.get_model('application_setting') + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False) + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first() + if application_setting.authentication and authentication_value is not None: + if authentication_value.get('type') == 'password': + if not self.auth_password(authentication_value, application_setting.authentication_value): + raise AppApiException(1005, "密码错误") + return True + + @staticmethod + def auth_password(source_authentication_value, authentication_value): + return source_authentication_value.get('value') == authentication_value.get('value') + class Edit(serializers.Serializer): name = serializers.CharField(required=False, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称")) @@ -762,6 +792,8 @@ class ApplicationSerializer(serializers.Serializer): 'avatar': application_setting.avatar, 'float_icon': application_setting.float_icon, 'authentication': application_setting.authentication, + 'authentication_type': application_setting.authentication_value.get( + 'type', 'password'), 'disclaimer': application_setting.disclaimer, 'disclaimer_value': application_setting.disclaimer_value, 'custom_theme': application_setting.custom_theme, diff --git a/apps/application/views/application_views.py b/apps/application/views/application_views.py index 8a7f4931c..64b6c367b 100644 --- a/apps/application/views/application_views.py +++ b/apps/application/views/application_views.py @@ -376,7 +376,9 @@ class Application(APIView): security=[]) def post(self, request: Request): return result.success( - ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token")}).auth( + ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token"), + 'authentication_value': request.data.get( + 'authentication_value')}).auth( request), headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Methods": "POST", @@ -539,12 +541,13 @@ class Application(APIView): authentication_classes = [TokenAuth] @action(methods=['POST'], detail=False) - @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], - [lambda r, keywords: Permission(group=Group.APPLICATION, - operate=Operate.USE, - dynamic_tag=keywords.get( - 'application_id'))], - compare=CompareConstants.AND)) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) def post(self, request: Request, application_id: str): return result.success( ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}) @@ -554,31 +557,33 @@ class Application(APIView): authentication_classes = [TokenAuth] @action(methods=['POST'], detail=False) - @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], - [lambda r, keywords: Permission(group=Group.APPLICATION, - operate=Operate.USE, - dynamic_tag=keywords.get( - 'application_id'))], - compare=CompareConstants.AND)) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) def post(self, request: Request, application_id: str): byte_data = ApplicationSerializer.Operate( data={'application_id': application_id, 'user_id': request.user.id}).text_to_speech( request.data.get('text')) return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', - 'Content-Disposition': 'attachment; filename="abc.mp3"'}) + 'Content-Disposition': 'attachment; filename="abc.mp3"'}) class PlayDemoText(APIView): authentication_classes = [TokenAuth] @action(methods=['POST'], detail=False) - @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], - [lambda r, keywords: Permission(group=Group.APPLICATION, - operate=Operate.USE, - dynamic_tag=keywords.get( - 'application_id'))], - compare=CompareConstants.AND)) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) def post(self, request: Request, application_id: str): byte_data = ApplicationSerializer.Operate( data={'application_id': application_id, 'user_id': request.user.id}).play_demo_text(request.data) return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', - 'Content-Disposition': 'attachment; filename="abc.mp3"'}) + 'Content-Disposition': 'attachment; filename="abc.mp3"'}) diff --git a/apps/common/auth/handle/impl/public_access_token.py b/apps/common/auth/handle/impl/public_access_token.py index 1655187a8..250f05efb 100644 --- a/apps/common/auth/handle/impl/public_access_token.py +++ b/apps/common/auth/handle/impl/public_access_token.py @@ -12,7 +12,9 @@ from application.models.api_key_model import ApplicationAccessToken from common.auth.handle.auth_base_handle import AuthBaseHandle from common.constants.authentication_type import AuthenticationType from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth -from common.exception.app_exception import AppAuthenticationFailed +from common.exception.app_exception import AppAuthenticationFailed, ChatException +from common.models.db_model_manage import DBModelManage +from common.util.common import password_encrypt class PublicAccessToken(AuthBaseHandle): @@ -29,6 +31,20 @@ class PublicAccessToken(AuthBaseHandle): auth_details = get_token_details() application_access_token = QuerySet(ApplicationAccessToken).filter( application_id=auth_details.get('application_id')).first() + application_setting_model = DBModelManage.get_model('application_setting') + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False) + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting = QuerySet(application_setting_model).filter(application_id=str( + application_access_token.application_id)).first() + if application_setting.authentication: + authentication = auth_details.get('authentication', {}) + if authentication is None: + authentication = {} + if application_setting.authentication_value.get('type') != authentication.get( + 'type') or password_encrypt( + application_setting.authentication_value.get('value')) != authentication.get('value'): + raise ChatException(1002, "身份验证信息不正确") if application_access_token is None: raise AppAuthenticationFailed(1002, "身份验证信息不正确") if not application_access_token.is_active: diff --git a/apps/common/util/common.py b/apps/common/util/common.py index f1586bb95..cbf6b0011 100644 --- a/apps/common/util/common.py +++ b/apps/common/util/common.py @@ -6,6 +6,7 @@ @date:2023/10/16 16:42 @desc: """ +import hashlib import importlib from functools import reduce from typing import Dict, List @@ -62,6 +63,18 @@ def flat_map(array: List[List]): return result +def password_encrypt(raw_password): + """ + 密码 md5加密 + :param raw_password: 密码 + :return: 加密后密码 + """ + md5 = hashlib.md5() # 2,实例化md5() 方法 + md5.update(raw_password.encode()) # 3,对字符串的字节类型加密 + result = md5.hexdigest() # 4,加密 + return result + + def post(post_function): def inner(func): def run(*args, **kwargs): diff --git a/ui/src/api/application.ts b/ui/src/api/application.ts index 4672d6f95..c06264997 100644 --- a/ui/src/api/application.ts +++ b/ui/src/api/application.ts @@ -122,11 +122,17 @@ const putAccessToken: ( "access_token": "string" } */ -const postAppAuthentication: (access_token: string, loading?: Ref) => Promise = ( - access_token, - loading -) => { - return post(`${prefix}/authentication`, { access_token }, undefined, loading) +const postAppAuthentication: ( + access_token: string, + loading?: Ref, + authentication_value?: any +) => Promise = (access_token, loading, authentication_value) => { + return post( + `${prefix}/authentication`, + { access_token: access_token, authentication_value }, + undefined, + loading + ) } /** diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 0ca571259..fdb28d4d4 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -365,8 +365,7 @@ function handleInputFieldList() { ?.filter((v: any) => v.id === 'base-node') .map((v: any) => { inputFieldList.value = v.properties.user_input_field_list - ? v.properties.user_input_field_list - .map((v: any) => { + ? v.properties.user_input_field_list.map((v: any) => { switch (v.type) { case 'input': return { @@ -404,51 +403,51 @@ function handleInputFieldList() { return v } }) - : v.properties.input_field_list ? v.properties.input_field_list - .filter((v: any) => v.assignment_method === 'user_input') - .map((v: any) => { - switch (v.type) { - case 'input': - return { - field: v.variable, - input_type: 'TextInput', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required - } - case 'select': - return { - field: v.variable, - input_type: 'SingleSelect', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required, - option_list: v.optionList.map((o: any) => { - return { key: o, value: o } - }) - } - case 'date': - return { - field: v.variable, - input_type: 'DatePicker', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required, - attrs: { - format: 'YYYY-MM-DD HH:mm:ss', - 'value-format': 'YYYY-MM-DD HH:mm:ss', - type: 'datetime' + : v.properties.input_field_list + ? v.properties.input_field_list + .filter((v: any) => v.assignment_method === 'user_input') + .map((v: any) => { + switch (v.type) { + case 'input': + return { + field: v.variable, + input_type: 'TextInput', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required } - } - default: - break - } - }) + case 'select': + return { + field: v.variable, + input_type: 'SingleSelect', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required, + option_list: v.optionList.map((o: any) => { + return { key: o, value: o } + }) + } + case 'date': + return { + field: v.variable, + input_type: 'DatePicker', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required, + attrs: { + format: 'YYYY-MM-DD HH:mm:ss', + 'value-format': 'YYYY-MM-DD HH:mm:ss', + type: 'datetime' + } + } + default: + break + } + }) : [] apiInputFieldList.value = v.properties.api_input_field_list - ? v.properties.api_input_field_list - .map((v: any) => { + ? v.properties.api_input_field_list.map((v: any) => { switch (v.type) { case 'input': return { @@ -488,45 +487,45 @@ function handleInputFieldList() { }) : v.properties.input_field_list ? v.properties.input_field_list - .filter((v: any) => v.assignment_method === 'api_input') - .map((v: any) => { - switch (v.type) { - case 'input': - return { - field: v.variable, - input_type: 'TextInput', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required - } - case 'select': - return { - field: v.variable, - input_type: 'SingleSelect', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required, - option_list: v.optionList.map((o: any) => { - return { key: o, value: o } - }) - } - case 'date': - return { - field: v.variable, - input_type: 'DatePicker', - label: v.name, - default_value: default_value[v.variable], - required: v.is_required, - attrs: { - format: 'YYYY-MM-DD HH:mm:ss', - 'value-format': 'YYYY-MM-DD HH:mm:ss', - type: 'datetime' + .filter((v: any) => v.assignment_method === 'api_input') + .map((v: any) => { + switch (v.type) { + case 'input': + return { + field: v.variable, + input_type: 'TextInput', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required } - } - default: - break - } - }) + case 'select': + return { + field: v.variable, + input_type: 'SingleSelect', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required, + option_list: v.optionList.map((o: any) => { + return { key: o, value: o } + }) + } + case 'date': + return { + field: v.variable, + input_type: 'DatePicker', + label: v.name, + default_value: default_value[v.variable], + required: v.is_required, + attrs: { + format: 'YYYY-MM-DD HH:mm:ss', + 'value-format': 'YYYY-MM-DD HH:mm:ss', + type: 'datetime' + } + } + default: + break + } + }) : [] }) } @@ -912,7 +911,7 @@ const mediaRecorderStatus = ref(true) const startRecording = async () => { try { // 取消录音控制台日志 - Recorder.CLog=function(){} + Recorder.CLog = function () {} mediaRecorderStatus.value = false handleTimeChange() mediaRecorder.value = new Recorder({ diff --git a/ui/src/request/index.ts b/ui/src/request/index.ts index 94760aeb4..4c8bbed16 100644 --- a/ui/src/request/index.ts +++ b/ui/src/request/index.ts @@ -40,6 +40,9 @@ instance.interceptors.response.use( (response: any) => { if (response.data) { if (response.data.code !== 200 && !(response.data instanceof Blob)) { + if (response.config.url.includes('/application/authentication')) { + return Promise.reject(response.data) + } if ( !response.config.url.includes('/valid') && !response.config.url.includes('/function_lib/debug') diff --git a/ui/src/stores/modules/application.ts b/ui/src/stores/modules/application.ts index 5c04650ec..40588fb3a 100644 --- a/ui/src/stores/modules/application.ts +++ b/ui/src/stores/modules/application.ts @@ -89,10 +89,14 @@ const useApplicationStore = defineStore({ }) }, - async asyncAppAuthentication(token: string, loading?: Ref) { + async asyncAppAuthentication( + token: string, + loading?: Ref, + authentication_value?: any + ) { return new Promise((resolve, reject) => { applicationApi - .postAppAuthentication(token, loading) + .postAppAuthentication(token, loading, authentication_value) .then((res) => { localStorage.setItem('accessToken', res.data) sessionStorage.setItem('accessToken', res.data) diff --git a/ui/src/views/chat/auth/component/password.vue b/ui/src/views/chat/auth/component/password.vue new file mode 100644 index 000000000..864dc383c --- /dev/null +++ b/ui/src/views/chat/auth/component/password.vue @@ -0,0 +1,83 @@ + + + diff --git a/ui/src/views/chat/auth/index.vue b/ui/src/views/chat/auth/index.vue new file mode 100644 index 000000000..1d17c4b5a --- /dev/null +++ b/ui/src/views/chat/auth/index.vue @@ -0,0 +1,56 @@ + + + diff --git a/ui/src/views/chat/base/index.vue b/ui/src/views/chat/base/index.vue index f562164f5..907dccb9d 100644 --- a/ui/src/views/chat/base/index.vue +++ b/ui/src/views/chat/base/index.vue @@ -34,49 +34,27 @@ diff --git a/ui/src/views/chat/pc/index.vue b/ui/src/views/chat/pc/index.vue index 567527e75..f4b4d4eb4 100644 --- a/ui/src/views/chat/pc/index.vue +++ b/ui/src/views/chat/pc/index.vue @@ -1,30 +1,5 @@