From 165de1ad739e0b9ea71ab6a02dfa5d2c12d03cf4 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Fri, 6 Jun 2025 22:28:21 +0800 Subject: [PATCH] feat: chat authentication (#3206) --- apps/application/flow/i_step_node.py | 4 +- .../impl/base_search_dataset_node.py | 10 +- ...hat_chatrecord_workflowversion_and_more.py | 54 ++++--- .../application/models/application_api_key.py | 13 -- apps/application/models/application_chat.py | 26 +++- apps/application/views/application_api_key.py | 5 +- apps/application/views/application_version.py | 3 +- apps/chat/api/chat_authentication_api.py | 24 +++ apps/chat/serializers/chat_authentication.py | 147 ++++++++++++++++++ apps/chat/urls.py | 2 + apps/chat/views/__init__.py | 1 + apps/chat/views/chat.py | 70 +++++++++ apps/common/auth/common.py | 64 ++++++++ .../handle/impl/chat_anonymous_user_token.py | 55 +++++++ apps/common/auth/handle/impl/user_token.py | 3 +- apps/common/constants/permission_constants.py | 4 +- apps/maxkb/settings/auth.py | 2 + apps/maxkb/settings/base.py | 1 + 18 files changed, 440 insertions(+), 48 deletions(-) create mode 100644 apps/chat/api/chat_authentication_api.py create mode 100644 apps/chat/serializers/chat_authentication.py create mode 100644 apps/chat/views/chat.py create mode 100644 apps/common/auth/common.py create mode 100644 apps/common/auth/handle/impl/chat_anonymous_user_token.py diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 045326c90..701d5d78d 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -19,7 +19,7 @@ from rest_framework.exceptions import ValidationError, ErrorDetail from application.flow.common import Answer, NodeChunk from application.models import ChatRecord -from application.models import ApplicationPublicAccessClient +from application.models import ApplicationChatClientStats from common.constants.authentication_type import AuthenticationType from common.field.common import InstanceField @@ -89,7 +89,7 @@ class WorkFlowPostHandler: chat_cache.set(chat_id, self.chat_info, timeout=60 * 30) if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - application_public_access_client = (QuerySet(ApplicationPublicAccessClient) + application_public_access_client = (QuerySet(ApplicationChatClientStats) .filter(client_id=self.client_id, application_id=self.chat_info.application.id).first()) if application_public_access_client is not None: diff --git a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py index 74e69c06d..3edbd77fe 100644 --- a/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py +++ b/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -9,19 +9,17 @@ import os from typing import List, Dict -from django.db.models import QuerySet from django.db import connection +from django.db.models import QuerySet + from application.flow.i_step_node import NodeResult from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode from common.config.embedding_config import VectorStore from common.db.search import native_search from common.utils.common import get_file_content -from knowledge.models import Document, Paragraph, Knowledge - -from models_provider.tools import get_model_instance_by_model_user_id +from knowledge.models import Document, Paragraph, Knowledge, SearchMode from maxkb.conf import PROJECT_DIR - -SearchMode = None +from models_provider.tools import get_model_instance_by_model_user_id def get_embedding_id(dataset_id_list): diff --git a/apps/application/migrations/0002_chat_chatrecord_workflowversion_and_more.py b/apps/application/migrations/0002_chat_chatrecord_workflowversion_and_more.py index 0d92af56f..3c1b72152 100644 --- a/apps/application/migrations/0002_chat_chatrecord_workflowversion_and_more.py +++ b/apps/application/migrations/0002_chat_chatrecord_workflowversion_and_more.py @@ -1,16 +1,16 @@ # Generated by Django 5.2 on 2025-06-04 11:57 -import application.models.application_chat -import common.encoder.encoder +import uuid + import django.contrib.postgres.fields import django.db.models.deletion -import uuid import uuid_utils.compat from django.db import migrations, models +import common.encoder.encoder + class Migration(migrations.Migration): - dependencies = [ ('application', '0001_initial'), ] @@ -21,12 +21,13 @@ class Migration(migrations.Migration): fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), - ('id', models.UUIDField(default=uuid.UUID('01973acd-fe4c-7fd1-94a8-f7cd668de562'), editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('id', models.UUIDField(default=uuid.UUID('01973acd-fe4c-7fd1-94a8-f7cd668de562'), editable=False, + primary_key=True, serialize=False, verbose_name='主键id')), ('abstract', models.CharField(max_length=1024, verbose_name='摘要')), - ('asker', models.JSONField(default=application.models.application_chat.default_asker, encoder=common.encoder.encoder.SystemEncoder, verbose_name='访问者')), ('client_id', models.UUIDField(default=None, null=True, verbose_name='客户端id')), ('is_deleted', models.BooleanField(default=False, verbose_name='')), - ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ('application', + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), ], options={ 'db_table': 'application_chat', @@ -37,16 +38,25 @@ class Migration(migrations.Migration): fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), - ('id', models.UUIDField(default=uuid_utils.compat.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), - ('vote_status', models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, verbose_name='投票')), + ('id', + models.UUIDField(default=uuid_utils.compat.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('vote_status', + models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, + verbose_name='投票')), ('problem_text', models.CharField(max_length=10240, verbose_name='问题')), ('answer_text', models.CharField(max_length=40960, verbose_name='答案')), - ('answer_text_list', django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), default=list, size=None, verbose_name='改进标注列表')), + ('answer_text_list', + django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(), default=list, size=None, + verbose_name='改进标注列表')), ('message_tokens', models.IntegerField(default=0, verbose_name='请求token数量')), ('answer_tokens', models.IntegerField(default=0, verbose_name='响应token数量')), ('const', models.IntegerField(default=0, verbose_name='总费用')), - ('details', models.JSONField(default=dict, encoder=common.encoder.encoder.SystemEncoder, verbose_name='对话详情')), - ('improve_paragraph_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='改进标注列表')), + ('details', + models.JSONField(default=dict, encoder=common.encoder.encoder.SystemEncoder, verbose_name='对话详情')), + ('improve_paragraph_id_list', + django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, + size=None, verbose_name='改进标注列表')), ('run_time', models.FloatField(default=0, verbose_name='运行时长')), ('index', models.IntegerField(verbose_name='对话下标')), ('chat', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.chat')), @@ -60,13 +70,17 @@ class Migration(migrations.Migration): fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), - ('id', models.UUIDField(default=uuid_utils.compat.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), - ('workspace_id', models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')), + ('id', + models.UUIDField(default=uuid_utils.compat.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('workspace_id', + models.CharField(db_index=True, default='default', max_length=64, verbose_name='工作空间id')), ('name', models.CharField(default='', max_length=128, verbose_name='版本名称')), ('publish_user_id', models.UUIDField(default=None, null=True, verbose_name='发布者id')), ('publish_user_name', models.CharField(default='', max_length=128, verbose_name='发布者名称')), ('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')), - ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ('application', + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), ], options={ 'db_table': 'application_work_flow_version', @@ -77,16 +91,20 @@ class Migration(migrations.Migration): fields=[ ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), - ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), ('client_id', models.UUIDField(default=uuid.uuid1, verbose_name='公共访问链接客户端id')), ('client_type', models.CharField(max_length=64, verbose_name='客户端类型')), ('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')), ('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')), - ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')), + ('application', + models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', + verbose_name='应用id')), ], options={ 'db_table': 'application_public_access_client', - 'indexes': [models.Index(fields=['application_id', 'client_id'], name='application_applica_8aaf45_idx')], + 'indexes': [ + models.Index(fields=['application_id', 'client_id'], name='application_applica_8aaf45_idx')], }, ), ] diff --git a/apps/application/models/application_api_key.py b/apps/application/models/application_api_key.py index 8970ba322..5bf139a87 100644 --- a/apps/application/models/application_api_key.py +++ b/apps/application/models/application_api_key.py @@ -24,16 +24,3 @@ class ApplicationApiKey(AppModelMixin): db_table = "application_api_key" -class ApplicationPublicAccessClient(AppModelMixin): - id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") - client_id = models.UUIDField(max_length=128, default=uuid.uuid1, verbose_name="公共访问链接客户端id") - client_type = models.CharField(max_length=64, verbose_name="客户端类型") - application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") - access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") - intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") - - class Meta: - db_table = "application_public_access_client" - indexes = [ - models.Index(fields=['application_id', 'client_id']), - ] diff --git a/apps/application/models/application_chat.py b/apps/application/models/application_chat.py index 428658f24..0846a9611 100644 --- a/apps/application/models/application_chat.py +++ b/apps/application/models/application_chat.py @@ -17,17 +17,20 @@ from common.encoder.encoder import SystemEncoder from common.mixins.app_model_mixin import AppModelMixin -def default_asker(): - return {'user_name': '游客'} +class ClientType(models.TextChoices): + ANONYMOUS_USER = "ANONYMOUS_USER", '匿名用户' + CHAT_USER = "CHAT_USER", "对话用户" + SYSTEM_API_KEY = "SYSTEM_API_KEY", "系统API_KEY" + APPLICATION_API_KEY = "APPLICATION_API_KEY", "应用API_KEY" class Chat(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7(), editable=False, verbose_name="主键id") application = models.ForeignKey(Application, on_delete=models.CASCADE) abstract = models.CharField(max_length=1024, verbose_name="摘要") - asker = models.JSONField(verbose_name="访问者", default=default_asker, encoder=SystemEncoder) client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) - is_deleted = models.BooleanField(verbose_name="", default=False) + client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices) + is_deleted = models.BooleanField(verbose_name="逻辑删除", default=False) class Meta: db_table = "application_chat" @@ -80,3 +83,18 @@ class ChatRecord(AppModelMixin): class Meta: db_table = "application_chat_record" + + +class ApplicationChatClientStats(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") + client_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="公共访问链接客户端id") + client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices) + application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") + access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") + intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") + + class Meta: + db_table = "application_chat_client_stats" + indexes = [ + models.Index(fields=['application_id', 'client_id']), + ] diff --git a/apps/application/views/application_api_key.py b/apps/application/views/application_api_key.py index 5f961a3cb..3a1a9bb9c 100644 --- a/apps/application/views/application_api_key.py +++ b/apps/application/views/application_api_key.py @@ -34,10 +34,11 @@ class ApplicationKey(APIView): parameters=ApplicationKeyCreateAPI.get_parameters(), tags=[_('Application Api Key')] # type: ignore ) + @log(menu='Application', operate="Add ApiKey", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_api_key_id'))) @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_API_KEY.get_workspace_application_permission()) - def post(self, request: Request, application_id: str, workspace_id: str): + def post(self, request: Request, workspace_id: str, application_id: str): return result.success(ApplicationKeySerializer( data={'application_id': application_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).generate()) @@ -51,7 +52,7 @@ class ApplicationKey(APIView): tags=[_('Application Api Key')] # type: ignore ) @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_API_KEY.get_workspace_application_permission()) - def get(self, request: Request, application_id: str, workspace_id: str): + def get(self, request: Request, workspace_id: str, application_id: str ): return result, success(ApplicationKeySerializer( data={'application_id': application_id, 'user_id': request.user.id, 'workspace_id': workspace_id}).list()) diff --git a/apps/application/views/application_version.py b/apps/application/views/application_version.py index 4fa731b7d..3aa105c02 100644 --- a/apps/application/views/application_version.py +++ b/apps/application/views/application_version.py @@ -12,7 +12,7 @@ from rest_framework.request import Request from rest_framework.views import APIView from application.api.application_version import ApplicationVersionListAPI, ApplicationVersionPageAPI, \ - ApplicationVersionAPI, ApplicationVersionOperateAPI + ApplicationVersionOperateAPI from application.serializers.application_version import ApplicationVersionSerializer from application.views import get_application_operation_object from common import result @@ -90,6 +90,7 @@ class ApplicationVersionView(APIView): responses=ApplicationVersionOperateAPI.get_response(), tags=[_('Application/Version')] # type: ignore ) + @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission()) @log(menu='Application', operate="Modify application version information", get_operation_object=lambda r, k: get_application_operation_object(k.get('application_id'))) def put(self, request: Request, workspace_id: str, application_id: str, work_flow_version_id: str): diff --git a/apps/chat/api/chat_authentication_api.py b/apps/chat/api/chat_authentication_api.py new file mode 100644 index 000000000..cae9006a5 --- /dev/null +++ b/apps/chat/api/chat_authentication_api.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat_authentication_api.py + @date:2025/6/6 19:59 + @desc: +""" +from chat.serializers.chat_authentication import AuthenticationSerializer +from common.mixins.api_mixin import APIMixin + + +class ChatAuthenticationAPI(APIMixin): + @staticmethod + def get_request(): + return AuthenticationSerializer() + + @staticmethod + def get_parameters(): + pass + + @staticmethod + def get_response(): + pass diff --git a/apps/chat/serializers/chat_authentication.py b/apps/chat/serializers/chat_authentication.py new file mode 100644 index 000000000..92dfab92c --- /dev/null +++ b/apps/chat/serializers/chat_authentication.py @@ -0,0 +1,147 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: ChatAuthentication.py + @date:2025/6/6 13:48 + @desc: +""" +import uuid + +from django.core import signing +from django.core.cache import cache +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.models import ApplicationAccessToken, ClientType, Application, ApplicationTypeChoices, WorkFlowVersion +from application.serializers.application import ApplicationSerializerModel +from common.auth.common import ChatUserToken, ChatAuthentication + +from common.constants.authentication_type import AuthenticationType +from common.constants.cache_version import Cache_Version +from common.database_model_manage.database_model_manage import DatabaseModelManage +from common.exception.app_exception import NotFound404, AppApiException, AppUnauthorizedFailed + + +def auth(application_id, access_token, authentication_value, token_details): + client_id = token_details.get('client_id') + if client_id is None: + client_id = str(uuid.uuid1()) + _type = AuthenticationType.CHAT_ANONYMOUS_USER + if authentication_value is not None: + application_setting_model = DatabaseModelManage.get_model('application_setting') + if application_setting_model is not None: + application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first() + if application_setting.authentication: + auth_type = application_setting.authentication_value.get('type') + auth_value = authentication_value.get(auth_type + '_value') + if auth_type == 'password': + if authentication_value.get('type') == 'password': + if auth_value == authentication_value.get(auth_type + '_value'): + return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER, + client_id, ChatAuthentication(auth_type, True, True)) + else: + raise AppApiException(500, '认证方式不匹配') + return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER, + client_id, ChatAuthentication(None, False, False)) + + +class AuthenticationSerializer(serializers.Serializer): + access_token = serializers.CharField(required=True, label=_("access_token")) + authentication_value = serializers.JSONField(required=False, allow_null=True, + label=_("Certification Information")) + + def auth(self, request, with_valid=True): + token = request.META.get('HTTP_AUTHORIZATION') + token_details = {} + try: + # 校验token + if token is not None: + token_details = signing.loads(token) + except Exception as e: + pass + if with_valid: + self.is_valid(raise_exception=True) + access_token = self.data.get("access_token") + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + authentication_value = self.data.get('authentication_value', None) + if application_access_token is not None and application_access_token.is_active: + chat_user_token = auth(application_access_token.application_id, access_token, authentication_value, + token_details) + + return chat_user_token.to_token() + else: + raise NotFound404(404, _("Invalid access_token")) + + +class ApplicationProfileSerializer(serializers.Serializer): + application_id = serializers.UUIDField(required=True, label=_("Application ID")) + + def profile(self, with_valid=True): + if with_valid: + self.is_valid() + application_id = self.data.get("application_id") + application = QuerySet(Application).get(id=application_id) + application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first() + if application_access_token is None: + raise AppUnauthorizedFailed(500, _("Illegal User")) + application_setting_model = DatabaseModelManage.get_model('application_setting') + if application.type == ApplicationTypeChoices.WORK_FLOW: + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by( + '-create_time')[0:1].first() + if work_flow_version is not None: + application.work_flow = work_flow_version.work_flow + + license_is_valid = cache.get(Cache_Version.SYSTEM.get_key(key='license_is_valid'), + version=Cache_Version.SYSTEM.get_version()) + application_setting_dict = {} + if application_setting_model is not None and license_is_valid: + application_setting = QuerySet(application_setting_model).filter( + application_id=application_access_token.application_id).first() + if application_setting is not None: + custom_theme = getattr(application_setting, 'custom_theme', {}) + float_location = getattr(application_setting, 'float_location', {}) + if not custom_theme: + application_setting.custom_theme = { + 'theme_color': '', + 'header_font_color': '' + } + if not float_location: + application_setting.float_location = { + 'x': {'type': '', 'value': ''}, + 'y': {'type': '', 'value': ''} + } + application_setting_dict = {'show_source': application_access_token.show_source, + 'show_history': application_setting.show_history, + 'draggable': application_setting.draggable, + 'show_guide': application_setting.show_guide, + 'avatar': application_setting.avatar, + 'show_avatar': application_setting.show_avatar, + 'float_icon': application_setting.float_icon, + 'authentication': application_setting.authentication, + 'authentication_type': application_setting.authentication_value.get( + 'type', 'password'), + 'login_value': application_setting.authentication_value.get( + 'login_value', []), + 'disclaimer': application_setting.disclaimer, + 'disclaimer_value': application_setting.disclaimer_value, + 'custom_theme': application_setting.custom_theme, + 'user_avatar': application_setting.user_avatar, + 'show_user_avatar': application_setting.show_user_avatar, + 'float_location': application_setting.float_location} + return {**ApplicationSerializerModel(application).data, + 'stt_model_id': application.stt_model_id, + 'tts_model_id': application.tts_model_id, + 'stt_model_enable': application.stt_model_enable, + 'tts_model_enable': application.tts_model_enable, + 'tts_type': application.tts_type, + 'tts_autoplay': application.tts_autoplay, + 'stt_autosend': application.stt_autosend, + 'file_upload_enable': application.file_upload_enable, + 'file_upload_setting': application.file_upload_setting, + 'work_flow': {'nodes': [node for node in ((application.work_flow or {}).get('nodes', []) or []) if + node.get('id') == 'base-node']}, + 'show_source': application_access_token.show_source, + 'language': application_access_token.language, + **application_setting_dict} diff --git a/apps/chat/urls.py b/apps/chat/urls.py index e7d9314f1..0f0bc12a2 100644 --- a/apps/chat/urls.py +++ b/apps/chat/urls.py @@ -6,4 +6,6 @@ app_name = 'chat' urlpatterns = [ path('chat/embed', views.ChatEmbedView.as_view()), + path('application/authentication', views.Authentication.as_view()), + path('profile', views.ApplicationProfile.as_view()) ] diff --git a/apps/chat/views/__init__.py b/apps/chat/views/__init__.py index 203186ac9..9a1e1c14f 100644 --- a/apps/chat/views/__init__.py +++ b/apps/chat/views/__init__.py @@ -7,3 +7,4 @@ @desc: """ from .chat_embed import * +from .chat import * diff --git a/apps/chat/views/chat.py b/apps/chat/views/chat.py new file mode 100644 index 000000000..bda65408e --- /dev/null +++ b/apps/chat/views/chat.py @@ -0,0 +1,70 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat.py + @date:2025/6/6 11:18 + @desc: +""" +from django.http import HttpResponse +from django.utils.translation import gettext_lazy as _ +from drf_spectacular.utils import extend_schema +from rest_framework.request import Request +from rest_framework.views import APIView + +from chat.api.chat_authentication_api import ChatAuthenticationAPI +from chat.serializers.chat_authentication import AuthenticationSerializer, ApplicationProfileSerializer +from common.auth import TokenAuth +from common.exception.app_exception import AppAuthenticationFailed +from common.result import result + + +class Authentication(APIView): + def options(self, request, *args, **kwargs): + return HttpResponse( + headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, ) + + @extend_schema( + methods=['POST'], + description=_('Application Certification'), + summary=_('Application Certification'), + operation_id=_('Application Certification'), # type: ignore + request=ChatAuthenticationAPI.get_request(), + responses=None, + tags=[_('Chat')] # type: ignore + ) + def post(self, request: Request): + return result.success( + AuthenticationSerializer(data={'access_token': request.data.get("access_token"), + 'authentication_value': request.data.get( + 'authentication_value')}).auth( + request), + headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"} + ) + + +class ApplicationProfile(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['GET'], + description=_("Get application related information"), + summary=_("Get application related information"), + operation_id=_("Get application related information"), # type: ignore + request=None, + responses=None, + tags=[_('Chat')] # type: ignore + ) + def get(self, request: Request): + if 'application_id' in request.auth.keywords: + return result.success(ApplicationProfileSerializer( + data={'application_id': request.auth.keywords.get('application_id')}).profile()) + raise AppAuthenticationFailed(401, "身份异常") + + +class ChatView(APIView): + pass diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py new file mode 100644 index 000000000..700ed12b4 --- /dev/null +++ b/apps/common/auth/common.py @@ -0,0 +1,64 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: common.py + @date:2025/6/6 19:55 + @desc: +""" +import json + +from django.core import signing + +from common.utils.rsa_util import encrypt, decrypt + + +class ChatAuthentication: + def __init__(self, auth_type: str | None, is_auth: bool, auth_passed: bool): + self.is_auth = is_auth + self.auth_passed = auth_passed + self.auth_type = auth_type + + def to_dict(self): + return {'is_auth': self.is_auth, 'auth_passed': self.auth_passed, 'auth_type': self.auth_type} + + def to_string(self): + return encrypt(json.dumps(self.to_dict())) + + @staticmethod + def new_instance(authentication: str): + auth = json.loads(decrypt(authentication)) + return ChatAuthentication(auth.get('auth_type'), auth.get('is_auth'), auth.get('auth_passed')) + + +class ChatUserToken: + def __init__(self, application_id, user_id, access_token, _type, client_type, client_id, + authentication: ChatAuthentication): + self.application_id = application_id + self.user_id = user_id, + self.access_token = access_token + self.type = _type + self.client_type = client_type + self.client_id = client_id + self.authentication = authentication + + def to_dict(self): + return { + 'application_id': str(self.application_id), + 'user_id': str(self.user_id), + 'access_token': self.access_token, + 'type': str(self.type.value), + 'client_type': str(self.client_type), + 'client_id': str(self.client_id), + 'authentication': self.authentication.to_string() + } + + def to_token(self): + return signing.dumps(self.to_dict()) + + @staticmethod + def new_instance(token_dict): + return ChatUserToken(token_dict.get('application_id'), token_dict.get('user_id'), + token_dict.get('access_token'), token_dict.get('type'), token_dict.get('client_type'), + token_dict.get('client_id'), + ChatAuthentication.new_instance(token_dict.get('authentication'))) diff --git a/apps/common/auth/handle/impl/chat_anonymous_user_token.py b/apps/common/auth/handle/impl/chat_anonymous_user_token.py new file mode 100644 index 000000000..0d8da19ce --- /dev/null +++ b/apps/common/auth/handle/impl/chat_anonymous_user_token.py @@ -0,0 +1,55 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat_anonymous_user_token.py + @date:2025/6/6 15:08 + @desc: +""" +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ + +from application.models import ApplicationAccessToken, ClientType +from common.auth.common import ChatUserToken + +from common.auth.handle.auth_base_handle import AuthBaseHandle +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth +from common.exception.app_exception import AppAuthenticationFailed, ChatException + + +class ChatAnonymousUserToken(AuthBaseHandle): + def support(self, request, token: str, get_token_details): + token_details = get_token_details() + if token_details is None: + return False + return ( + 'application_id' in token_details and + 'access_token' in token_details and + token_details.get('type') == AuthenticationType.CHAT_ANONYMOUS_USER.value) + + def handle(self, request, token: str, get_token_details): + auth_details = get_token_details() + chat_user_token = ChatUserToken.new_instance(auth_details) + application_id = chat_user_token.application_id + access_token = chat_user_token.access_token + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=application_id).first() + if application_access_token is None: + raise AppAuthenticationFailed(1002, _('Authentication information is incorrect')) + if not application_access_token.is_active: + raise AppAuthenticationFailed(1002, _('Authentication information is incorrect')) + if not application_access_token.access_token == access_token: + raise AppAuthenticationFailed(1002, _('Authentication information is incorrect')) + # 匿名用户 除了/api/application/profile 都需要校验是否开启了密码认证 + if request.path != '/api/application/profile': + if chat_user_token.authentication.is_auth and not chat_user_token.authentication.auth_passed: + raise ChatException(1002, _('Authentication information is incorrect')) + return None, Auth( + current_role_list=[RoleConstants.CHAT_ANONYMOUS_USER], + permission_list=[ + Permission(group=Group.APPLICATION, + operate=Operate.USE)], + application_id=application_access_token.application_id, + client_id=auth_details.get('client_id'), + client_type=ClientType.ANONYMOUS_USER) diff --git a/apps/common/auth/handle/impl/user_token.py b/apps/common/auth/handle/impl/user_token.py index 22b6b113a..5c15d11bb 100644 --- a/apps/common/auth/handle/impl/user_token.py +++ b/apps/common/auth/handle/impl/user_token.py @@ -15,6 +15,7 @@ from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from common.auth.handle.auth_base_handle import AuthBaseHandle +from common.constants.authentication_type import AuthenticationType from common.constants.cache_version import Cache_Version from common.constants.permission_constants import Auth, PermissionConstants, ResourcePermissionGroup, \ get_permission_list_by_resource_group, ResourceAuthType, \ @@ -233,7 +234,7 @@ class UserToken(AuthBaseHandle): auth_details = get_token_details() if auth_details is None: return False - return True + return 'id' in auth_details and auth_details.get('type') == AuthenticationType.SYSTEM_USER.value def handle(self, request, token: str, get_token_details): version, get_key = Cache_Version.TOKEN.value diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 9582fbdb9..bd749f345 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -170,6 +170,8 @@ class RoleConstants(Enum): ADMIN = Role("ADMIN", '超级管理员', RoleGroup.SYSTEM_USER) WORKSPACE_MANAGE = Role("WORKSPACE_MANAGE", '工作空间管理员', RoleGroup.SYSTEM_USER) USER = Role("USER", '普通用户', RoleGroup.SYSTEM_USER) + CHAT_ANONYMOUS_USER = Role("CHAT_ANONYMOUS_USER", "对话匿名用户", RoleGroup.CHAT_USER) + CHAT_USER = Role("CHAT_USER", "对话用户", RoleGroup.CHAT_USER) def get_workspace_role(self): return lambda r, kwargs: Role(name=self.value.name, @@ -901,7 +903,7 @@ class Auth: """ def __init__(self, - current_role_list: List[Role], + current_role_list: List[RoleConstants | Role], permission_list: List[PermissionConstants | Permission], **keywords): # 权限列表 diff --git a/apps/maxkb/settings/auth.py b/apps/maxkb/settings/auth.py index 2b0fc7e1d..739536b1f 100644 --- a/apps/maxkb/settings/auth.py +++ b/apps/maxkb/settings/auth.py @@ -7,7 +7,9 @@ @desc: """ USER_TOKEN_AUTH = 'common.auth.handle.impl.user_token.UserToken' +CHAT_ANONYMOUS_USER_AURH = 'common.auth.handle.impl.chat_anonymous_user_token.ChatAnonymousUserToken' AUTH_HANDLES = [ USER_TOKEN_AUTH, + CHAT_ANONYMOUS_USER_AURH ] diff --git a/apps/maxkb/settings/base.py b/apps/maxkb/settings/base.py index 478c92c02..64c9c8b40 100644 --- a/apps/maxkb/settings/base.py +++ b/apps/maxkb/settings/base.py @@ -46,6 +46,7 @@ INSTALLED_APPS = [ 'models_provider', 'django_celery_beat', 'application', + 'chat' 'oss' ]