From 3ea41c3297a7a080affa1df3216b494fdfb1a4ef Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 30 Dec 2024 10:24:36 +0800 Subject: [PATCH] fix: Application access restrictions are not separately counted for each application (#1937) --- .../step/chat_step/impl/base_chat_step.py | 16 +++++---- apps/application/flow/i_step_node.py | 4 ++- ...onpublicaccessclient_client_id_and_more.py | 34 +++++++++++++++++++ apps/application/models/api_key_model.py | 6 +++- .../serializers/chat_message_serializers.py | 11 +++--- 5 files changed, 58 insertions(+), 13 deletions(-) create mode 100644 apps/application/migrations/0021_applicationpublicaccessclient_client_id_and_more.py diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 1aaf7b56c..6adff00c4 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -28,9 +28,11 @@ from common.constants.authentication_type import AuthenticationType from setting.models_provider.tools import get_model_instance_by_model_user_id -def add_access_num(client_id=None, client_type=None): - if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first() +def add_access_num(client_id=None, client_type=None, application_id=None): + if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None: + application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id, + application_id=application_id) + .first()) if application_public_access_client is not None: application_public_access_client.access_num = application_public_access_client.access_num + 1 application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 @@ -90,14 +92,14 @@ def event_content(response, request_token, response_token, {'node_is_end': True, 'view_type': 'many_view', 'node_type': 'ai-chat-node'}) - add_access_num(client_id, client_type) + add_access_num(client_id, client_type, manage.context.get('application_id')) except Exception as e: logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') all_text = '异常' + str(e) write_context(step, manage, 0, 0, all_text) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, all_text, manage, step, padding_problem_text, client_id) - add_access_num(client_id, client_type) + add_access_num(client_id, client_type, manage.context.get('application_id')) yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), all_text, 'ai-chat-node', [], True, 0, 0, @@ -241,7 +243,7 @@ class BaseChatStep(IChatStep): write_context(self, manage, request_token, response_token, chat_result.content) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, chat_result.content, manage, self, padding_problem_text, client_id) - add_access_num(client_id, client_type) + add_access_num(client_id, client_type, manage.context.get('application_id')) return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), chat_result.content, True, request_token, response_token) @@ -250,6 +252,6 @@ class BaseChatStep(IChatStep): write_context(self, manage, 0, 0, all_text) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, all_text, manage, self, padding_problem_text, client_id) - add_access_num(client_id, client_type) + add_access_num(client_id, client_type, manage.context.get('application_id')) return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0, 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR) diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index e4279b949..211fd4e2e 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -87,7 +87,9 @@ class WorkFlowPostHandler: chat_cache.set(chat_id, self.chat_info, timeout=60 * 30) if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first() + application_public_access_client = (QuerySet(ApplicationPublicAccessClient) + .filter(client_id=self.client_id, + application_id=self.chat_info.application.id).first()) if application_public_access_client is not None: application_public_access_client.access_num = application_public_access_client.access_num + 1 application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 diff --git a/apps/application/migrations/0021_applicationpublicaccessclient_client_id_and_more.py b/apps/application/migrations/0021_applicationpublicaccessclient_client_id_and_more.py new file mode 100644 index 000000000..356ff4dff --- /dev/null +++ b/apps/application/migrations/0021_applicationpublicaccessclient_client_id_and_more.py @@ -0,0 +1,34 @@ +# Generated by Django 4.2.15 on 2024-12-27 18:42 + +from django.db import migrations, models +import uuid + +run_sql = """ +UPDATE application_public_access_client +SET client_id="id" +""" + + +class Migration(migrations.Migration): + dependencies = [ + ('application', '0020_application_record_update_time'), + ] + + operations = [ + migrations.AddField( + model_name='applicationpublicaccessclient', + name='client_id', + field=models.UUIDField(default=uuid.uuid1, verbose_name='公共访问链接客户端id'), + ), + migrations.AlterField( + model_name='applicationpublicaccessclient', + name='id', + field=models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id'), + ), + migrations.AddIndex( + model_name='applicationpublicaccessclient', + index=models.Index(fields=['client_id'], name='application_client__4de9af_idx'), + ), + migrations.RunSQL(run_sql) + ] diff --git a/apps/application/models/api_key_model.py b/apps/application/models/api_key_model.py index 965e1f1c4..6b515de86 100644 --- a/apps/application/models/api_key_model.py +++ b/apps/application/models/api_key_model.py @@ -50,10 +50,14 @@ class ApplicationAccessToken(AppModelMixin): class ApplicationPublicAccessClient(AppModelMixin): - id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id") + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + client_id = models.UUIDField(max_length=128, default=uuid.uuid1, verbose_name="公共访问链接客户端id") application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") class Meta: db_table = "application_public_access_client" + indexes = [ + models.Index(fields=['client_id']), + ] diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 3e779c965..04d2d99ba 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -6,8 +6,8 @@ @date:2023/11/14 13:51 @desc: """ -from datetime import datetime import uuid +from datetime import datetime from typing import List, Dict from uuid import UUID @@ -107,7 +107,8 @@ class ChatInfo: 'search_mode': self.application.dataset_setting.get( 'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding', 'no_references_setting': self.get_no_references_setting(self.application.dataset_setting, model_setting), - 'user_id': self.application.user_id + 'user_id': self.application.user_id, + 'application_id': self.application.id } def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, @@ -258,9 +259,11 @@ class ChatMessageSerializer(serializers.Serializer): def is_valid_intraday_access_num(self): if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first() + access_client = QuerySet(ApplicationPublicAccessClient).filter(client_id=self.data.get('client_id'), + application_id=self.data.get( + 'application_id')).first() if access_client is None: - access_client = ApplicationPublicAccessClient(id=self.data.get('client_id'), + access_client = ApplicationPublicAccessClient(client_id=self.data.get('client_id'), application_id=self.data.get('application_id'), access_num=0, intraday_access_num=0)