From 3807cf19600ba3fb6fea8bb77a47810fb45d64b4 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Mon, 9 Jun 2025 16:18:43 +0800 Subject: [PATCH] feat: application chat (#3213) --- .../chat_pipeline/I_base_chat_pipeline.py | 12 +- .../step/chat_step/i_chat_step.py | 9 +- .../step/chat_step/impl/base_chat_step.py | 35 +- .../impl/base_generate_human_message_step.py | 2 +- .../i_search_dataset_step.py | 8 +- .../impl/base_search_dataset_step.py | 21 +- apps/application/flow/i_step_node.py | 24 +- ...application_applica_f89647_idx_and_more.py | 56 +++ ...stats_applicationchatuserstats_and_more.py | 32 ++ apps/application/models/application_chat.py | 20 +- apps/application/serializers/common.py | 137 +++++++ apps/chat/api/chat_api.py | 29 ++ apps/chat/api/chat_authentication_api.py | 41 ++- apps/chat/serializers/chat.py | 346 ++++++++++++++++++ apps/chat/serializers/chat_authentication.py | 74 ++-- apps/chat/urls.py | 7 +- apps/chat/views/chat.py | 76 +++- apps/common/auth/common.py | 14 +- .../handle/impl/chat_anonymous_user_token.py | 11 +- apps/common/constants/cache_version.py | 3 + apps/common/constants/permission_constants.py | 16 + 21 files changed, 835 insertions(+), 138 deletions(-) create mode 100644 apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py create mode 100644 apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py create mode 100644 apps/application/serializers/common.py create mode 100644 apps/chat/api/chat_api.py create mode 100644 apps/chat/serializers/chat.py diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py index a35bdc39c..8ef4896f3 100644 --- a/apps/application/chat_pipeline/I_base_chat_pipeline.py +++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py @@ -12,17 +12,17 @@ from typing import Type from rest_framework import serializers -from dataset.models import Paragraph +from knowledge.models import Paragraph class ParagraphPipelineModel: - def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str, + def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str, is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str, hit_handling_method: str, directly_return_similarity: float, meta: dict = None): self.id = _id self.document_id = document_id - self.dataset_id = dataset_id + self.knowledge_id = knowledge_id self.content = content self.title = title self.status = status, @@ -39,7 +39,7 @@ class ParagraphPipelineModel: return { 'id': self.id, 'document_id': self.document_id, - 'dataset_id': self.dataset_id, + 'knowledge_id': self.knowledge_id, 'content': self.content, 'title': self.title, 'status': self.status, @@ -66,7 +66,7 @@ class ParagraphPipelineModel: if isinstance(paragraph, Paragraph): self.paragraph = {'id': paragraph.id, 'document_id': paragraph.document_id, - 'dataset_id': paragraph.dataset_id, + 'knowledge_id': paragraph.knowledge_id, 'content': paragraph.content, 'title': paragraph.title, 'status': paragraph.status, @@ -106,7 +106,7 @@ class ParagraphPipelineModel: def build(self): return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')), - str(self.paragraph.get('dataset_id')), + str(self.paragraph.get('knowledge_id')), self.paragraph.get('content'), self.paragraph.get('title'), self.paragraph.get('status'), self.paragraph.get('is_active'), diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py index 4abda510d..1451cd88e 100644 --- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py @@ -44,7 +44,7 @@ class PostResponseHandler: @abstractmethod def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str, answer_text, - manage, step, padding_problem_text: str = None, client_id=None, **kwargs): + manage, step, padding_problem_text: str = None, **kwargs): pass @@ -68,8 +68,9 @@ class IChatStep(IBaseChatPipelineStep): label=_("Completion Question")) # 是否使用流的形式输出 stream = serializers.BooleanField(required=False, label=_("Streaming Output")) - client_id = serializers.CharField(required=True, label=_("Client id")) - client_type = serializers.CharField(required=True, label=_("Client Type")) + chat_user_id = serializers.CharField(required=True, label=_("Chat user id")) + + chat_user_type = serializers.CharField(required=True, label=_("Chat user Type")) # 未查询到引用分段 no_references_setting = NoReferencesSetting(required=True, label=_("No reference segment settings")) @@ -104,6 +105,6 @@ class IChatStep(IBaseChatPipelineStep): user_id: str = None, paragraph_list=None, manage: PipelineManage = None, - padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None, + padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None, no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index 4e5142ac2..13cab54e2 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -25,15 +25,16 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode from application.chat_pipeline.pipeline_manage import PipelineManage from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler from application.flow.tools import Reasoning -from application.models.application_api_key import ApplicationPublicAccessClient -from common.constants.authentication_type import AuthenticationType +from application.models import ApplicationChatUserStats, ChatUserType from models_provider.tools import get_model_instance_by_model_user_id -def add_access_num(client_id=None, client_type=None, application_id=None): - if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None: - application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id, - application_id=application_id) +def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None): + if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__( + chat_user_type) and application_id is not None: + application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id, + chat_user_type=chat_user_type, + application_id=application_id) .first()) if application_public_access_client is not None: application_public_access_client.access_num = application_public_access_client.access_num + 1 @@ -124,11 +125,9 @@ def event_content(response, request_token = 0 response_token = 0 write_context(step, manage, request_token, response_token, all_text) - asker = manage.context.get('form_data', {}).get('asker', None) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, - all_text, manage, step, padding_problem_text, client_id, - reasoning_content=reasoning_content if reasoning_content_enable else '' - , asker=asker) + all_text, manage, step, padding_problem_text, + reasoning_content=reasoning_content if reasoning_content_enable else '') yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', [], '', True, request_token, response_token, @@ -139,10 +138,8 @@ def event_content(response, logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') all_text = 'Exception:' + str(e) write_context(step, manage, 0, 0, all_text) - asker = manage.context.get('form_data', {}).get('asker', None) post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text, - all_text, manage, step, padding_problem_text, client_id, reasoning_content='', - asker=asker) + all_text, manage, step, padding_problem_text, reasoning_content='') add_access_num(client_id, client_type, manage.context.get('application_id')) yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node', [], all_text, @@ -165,7 +162,7 @@ class BaseChatStep(IChatStep): manage: PipelineManage = None, padding_problem_text: str = None, stream: bool = True, - client_id=None, client_type=None, + chat_user_id=None, chat_user_type=None, no_references_setting=None, model_params_setting=None, model_setting=None, @@ -175,12 +172,13 @@ class BaseChatStep(IChatStep): if stream: return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text, client_id, client_type, no_references_setting, + manage, padding_problem_text, chat_user_id, chat_user_type, + no_references_setting, model_setting) else: return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model, paragraph_list, - manage, padding_problem_text, client_id, client_type, no_references_setting, + manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting, model_setting) def get_details(self, manage, **kwargs): @@ -235,7 +233,7 @@ class BaseChatStep(IChatStep): paragraph_list=None, manage: PipelineManage = None, padding_problem_text: str = None, - client_id=None, client_type=None, + chat_user_id=None, chat_user_type=None, no_references_setting=None, model_setting=None): chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list, @@ -244,7 +242,8 @@ class BaseChatStep(IChatStep): r = StreamingHttpResponse( streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list, post_response_handler, manage, self, chat_model, message_list, problem_text, - padding_problem_text, client_id, client_type, is_ai_chat, model_setting), + padding_problem_text, chat_user_id, chat_user_type, is_ai_chat, + model_setting), content_type='text/event-stream;charset=utf-8') r['Cache-Control'] = 'no-cache' diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py index 68cfbbcb9..edb4d9a72 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -15,7 +15,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ IGenerateHumanMessageStep from application.models import ChatRecord -from common.util.split_model import flat_map +from common.utils.common import flat_map class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py index f086e72be..bb08ca9e2 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -26,8 +26,8 @@ class ISearchDatasetStep(IBaseChatPipelineStep): padding_problem_text = serializers.CharField(required=False, label=_("System completes question text")) # 需要查询的数据集id列表 - dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), - label=_("Dataset id list")) + knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + label=_("Dataset id list")) # 需要排除的文档id exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), label=_("List of document ids to exclude")) @@ -55,7 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): self.context['paragraph_list'] = paragraph_list @abstractmethod - def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, user_id=None, @@ -65,7 +65,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep): :param similarity: 相关性 :param top_n: 查询多少条 :param problem_text: 用户问题 - :param dataset_id_list: 需要查询的数据集id列表 + :param knowledge_id_list: 需要查询的数据集id列表 :param exclude_document_id_list: 需要排除的文档id :param exclude_paragraph_id_list: 需要排除段落id :param padding_problem_text 补全问题 diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py index b7f069785..aacd36ed8 100644 --- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py +++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -35,42 +35,33 @@ def get_model_by_id(_id, user_id): return model -def get_embedding_id(dataset_id_list): -<<<<<<< Updated upstream:apps/chat/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py - dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) - if len(set([dataset.embedding_model_id for dataset in dataset_list])) > 1: - raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled.")) - if len(dataset_list) == 0: - raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base")) - return dataset_list[0].embedding_model_id -======= - knowledge_list = QuerySet(Knowledge).filter(id__in=dataset_id_list) +def get_embedding_id(knowledge_id_list): + knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list) if len(set([knowledge.embedding_mode_id for knowledge in knowledge_list])) > 1: raise Exception( _("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled.")) if len(knowledge_list) == 0: raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base")) return knowledge_list[0].embedding_mode_id ->>>>>>> Stashed changes:apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py class BaseSearchDatasetStep(ISearchDatasetStep): - def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str], exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, search_mode: str = None, user_id=None, **kwargs) -> List[ParagraphPipelineModel]: - if len(dataset_id_list) == 0: + if len(knowledge_id_list) == 0: return [] exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text - model_id = get_embedding_id(dataset_id_list) + model_id = get_embedding_id(knowledge_id_list) model = get_model_by_id(model_id, user_id) self.context['model_name'] = model.name embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) embedding_value = embedding_model.embed_query(exec_problem_text) vector = VectorStore.get_embedding_vector() - embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, + embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list, exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) if embedding_list is None: return [] diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py index 701d5d78d..2fd61408a 100644 --- a/apps/application/flow/i_step_node.py +++ b/apps/application/flow/i_step_node.py @@ -18,8 +18,8 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError, ErrorDetail from application.flow.common import Answer, NodeChunk -from application.models import ChatRecord -from application.models import ApplicationChatClientStats +from application.models import ChatRecord, ChatUserType +from application.models import ApplicationChatUserStats from common.constants.authentication_type import AuthenticationType from common.field.common import InstanceField @@ -45,10 +45,10 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict): class WorkFlowPostHandler: - def __init__(self, chat_info, client_id, client_type): + def __init__(self, chat_info, chat_user_id, chat_user_type): self.chat_info = chat_info - self.client_id = client_id - self.client_type = client_type + self.chat_user_id = chat_user_id + self.chat_user_type = chat_user_type def handler(self, chat_id, chat_record_id, @@ -84,13 +84,13 @@ class WorkFlowPostHandler: run_time=time.time() - workflow.context['start_time'], index=0) asker = workflow.context.get('asker', None) - self.chat_info.append_chat_record(chat_record, self.client_id, asker) - # 重新设置缓存 - chat_cache.set(chat_id, - self.chat_info, timeout=60 * 30) - if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: - application_public_access_client = (QuerySet(ApplicationChatClientStats) - .filter(client_id=self.client_id, + self.chat_info.append_chat_record(chat_record) + self.chat_info.set_cahce() + if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__( + self.chat_user_type): + application_public_access_client = (QuerySet(ApplicationChatUserStats) + .filter(chat_user_id=self.chat_user_id, + chat_user_type=self.chat_user_type, application_id=self.chat_info.application.id).first()) if application_public_access_client is not None: application_public_access_client.access_num = application_public_access_client.access_num + 1 diff --git a/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py b/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py new file mode 100644 index 000000000..c9346dcbe --- /dev/null +++ b/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py @@ -0,0 +1,56 @@ +# Generated by Django 5.2 on 2025-06-09 05:55 + +import uuid +import uuid_utils.compat +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0003_applicationaccesstoken_show_exec_chat_client_type_and_more'), + ] + + operations = [ + migrations.RemoveIndex( + model_name='applicationchatclientstats', + name='application_applica_f89647_idx', + ), + migrations.RenameField( + model_name='chat', + old_name='client_id', + new_name='chat_user_id', + ), + migrations.RenameField( + model_name='chat', + old_name='client_type', + new_name='chat_user_type', + ), + migrations.RemoveField( + model_name='applicationchatclientstats', + name='client_id', + ), + migrations.RemoveField( + model_name='applicationchatclientstats', + name='client_type', + ), + migrations.AddField( + model_name='applicationchatclientstats', + name='chat_user_id', + field=models.UUIDField(default=uuid_utils.compat.uuid7, verbose_name='对话用户id'), + ), + migrations.AddField( + model_name='applicationchatclientstats', + name='chat_user_type', + field=models.CharField(choices=[('ANONYMOUS_USER', '匿名用户'), ('CHAT_USER', '对话用户'), ('SYSTEM_API_KEY', '系统API_KEY'), ('APPLICATION_API_KEY', '应用API_KEY')], default='ANONYMOUS_USER', max_length=64, verbose_name='对话用户类型'), + ), + migrations.AlterField( + model_name='chat', + name='id', + field=models.UUIDField(default=uuid.UUID('01975341-b4e8-7d52-913b-1bb67d7d8107'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'), + ), + migrations.AddIndex( + model_name='applicationchatclientstats', + index=models.Index(fields=['application_id', 'chat_user_id'], name='application_applica_23b4d2_idx'), + ), + ] diff --git a/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py b/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py new file mode 100644 index 000000000..9962ee48b --- /dev/null +++ b/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py @@ -0,0 +1,32 @@ +# Generated by Django 5.2 on 2025-06-09 07:31 + +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more'), + ] + + operations = [ + migrations.RenameModel( + old_name='ApplicationChatClientStats', + new_name='ApplicationChatUserStats', + ), + migrations.RenameIndex( + model_name='applicationchatuserstats', + new_name='application_applica_1652ba_idx', + old_name='application_applica_23b4d2_idx', + ), + migrations.AlterField( + model_name='chat', + name='id', + field=models.UUIDField(default=uuid.UUID('01975399-efa5-7dc3-8f97-edc67332ed24'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'), + ), + migrations.AlterModelTable( + name='applicationchatuserstats', + table='application_chat_user_stats', + ), + ] diff --git a/apps/application/models/application_chat.py b/apps/application/models/application_chat.py index 325551ebb..8c2a22633 100644 --- a/apps/application/models/application_chat.py +++ b/apps/application/models/application_chat.py @@ -17,7 +17,7 @@ from common.encoder.encoder import SystemEncoder from common.mixins.app_model_mixin import AppModelMixin -class ClientType(models.TextChoices): +class ChatUserType(models.TextChoices): ANONYMOUS_USER = "ANONYMOUS_USER", '匿名用户' CHAT_USER = "CHAT_USER", "对话用户" SYSTEM_API_KEY = "SYSTEM_API_KEY", "系统API_KEY" @@ -28,9 +28,9 @@ class Chat(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7(), editable=False, verbose_name="主键id") application = models.ForeignKey(Application, on_delete=models.CASCADE) abstract = models.CharField(max_length=1024, verbose_name="摘要") - client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) - client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices, - default=ClientType.ANONYMOUS_USER) + chat_user_id = models.UUIDField(verbose_name="客户端id", default=None, null=True) + chat_user_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ChatUserType.choices, + default=ChatUserType.ANONYMOUS_USER) is_deleted = models.BooleanField(verbose_name="逻辑删除", default=False) class Meta: @@ -86,17 +86,17 @@ class ChatRecord(AppModelMixin): db_table = "application_chat_record" -class ApplicationChatClientStats(AppModelMixin): +class ApplicationChatUserStats(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id") - client_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="公共访问链接客户端id") - client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices, - default=ClientType.ANONYMOUS_USER) + chat_user_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="对话用户id") + chat_user_type = models.CharField(max_length=64, verbose_name="对话用户类型", choices=ChatUserType.choices, + default=ChatUserType.ANONYMOUS_USER) application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") class Meta: - db_table = "application_chat_client_stats" + db_table = "application_chat_user_stats" indexes = [ - models.Index(fields=['application_id', 'client_id']), + models.Index(fields=['application_id', 'chat_user_id']), ] diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py new file mode 100644 index 000000000..8af009077 --- /dev/null +++ b/apps/application/serializers/common.py @@ -0,0 +1,137 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: common.py + @date:2025/6/9 13:42 + @desc: +""" +from datetime import datetime +from typing import List + +from django.core.cache import cache +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ + +from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler +from application.models import Application, WorkFlowVersion, ChatRecord, Chat +from common.constants.cache_version import Cache_Version +from models_provider.models import Model +from models_provider.tools import get_model_credential + + +class ChatInfo: + def __init__(self, + chat_id: str, + chat_user_id: str, + chat_user_type: str, + knowledge_id_list: List[str], + exclude_document_id_list: list[str], + application: Application, + work_flow_version: WorkFlowVersion = None): + """ + :param chat_id: 对话id + :param chat_user_id 对话用户id + :param chat_user_type 对话用户类型 + :param knowledge_id_list: 知识库列表 + :param exclude_document_id_list: 排除的文档 + :param application: 应用信息 + """ + self.chat_id = chat_id + self.chat_user_id = chat_user_id + self.chat_user_type = chat_user_type + self.application = application + self.knowledge_id_list = knowledge_id_list + self.exclude_document_id_list = exclude_document_id_list + self.chat_record_list: List[ChatRecord] = [] + self.work_flow_version = work_flow_version + + @staticmethod + def get_no_references_setting(knowledge_setting, model_setting): + no_references_setting = knowledge_setting.get( + 'no_references_setting', { + 'status': 'ai_questioning', + 'value': '{question}'}) + if no_references_setting.get('status') == 'ai_questioning': + no_references_prompt = model_setting.get('no_references_prompt', '{question}') + no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}" + return no_references_setting + + def to_base_pipeline_manage_params(self): + knowledge_setting = self.application.knowledge_setting + model_setting = self.application.model_setting + model_id = self.application.model.id if self.application.model is not None else None + model_params_setting = None + if model_id is not None: + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data() + return { + 'knowledge_id_list': self.knowledge_id_list, + 'exclude_document_id_list': self.exclude_document_id_list, + 'exclude_paragraph_id_list': [], + 'top_n': knowledge_setting.get('top_n') or 3, + 'similarity': knowledge_setting.get('similarity') or 0.6, + 'max_paragraph_char_number': knowledge_setting.get('max_paragraph_char_number') or 5000, + 'history_chat_record': self.chat_record_list, + 'chat_id': self.chat_id, + 'dialogue_number': self.application.dialogue_number, + 'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len( + self.application.problem_optimization_prompt) > 0 else _( + "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the tag"), + 'prompt': model_setting.get( + 'prompt') if 'prompt' in model_setting and len(model_setting.get( + 'prompt')) > 0 else Application.get_default_model_prompt(), + 'system': model_setting.get( + 'system', None), + 'model_id': model_id, + 'problem_optimization': self.application.problem_optimization, + 'stream': True, + 'model_setting': model_setting, + 'model_params_setting': model_params_setting if self.application.model_params_setting is None or len( + self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, + 'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding', + 'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting), + 'user_id': self.application.user_id, + 'application_id': self.application.id + } + + def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler, + exclude_paragraph_id_list, chat_user_id: str, chat_user_type, stream=True, + form_data=None): + if form_data is None: + form_data = {} + params = self.to_base_pipeline_manage_params() + return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler, + 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id, + 'chat_user_type': chat_user_type, 'form_data': form_data} + + def append_chat_record(self, chat_record: ChatRecord): + chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else "" + chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else "" + is_save = True + # 存入缓存中 + for index in range(len(self.chat_record_list)): + record = self.chat_record_list[index] + if record.id == chat_record.id: + self.chat_record_list[index] = chat_record + is_save = False + if is_save: + self.chat_record_list.append(chat_record) + cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), + timeout=60 * 30) + if self.application.id is not None: + Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024], + chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save() + else: + QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now()) + # 插入会话记录 + chat_record.save() + + def set_cache(self): + cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), + timeout=60 * 30) + + @staticmethod + def get_cache(chat_id): + return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) diff --git a/apps/chat/api/chat_api.py b/apps/chat/api/chat_api.py new file mode 100644 index 000000000..69b45da32 --- /dev/null +++ b/apps/chat/api/chat_api.py @@ -0,0 +1,29 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat_api.py + @date:2025/6/9 15:23 + @desc: +""" +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter + +from chat.serializers.chat import ChatMessageSerializers +from common.mixins.api_mixin import APIMixin + + +class ChatAPI(APIMixin): + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="chat_id", + description="对话id", + type=OpenApiTypes.STR, + location='path', + required=True, + )] + + @staticmethod + def get_request(): + return ChatMessageSerializers diff --git a/apps/chat/api/chat_authentication_api.py b/apps/chat/api/chat_authentication_api.py index cae9006a5..89fbd1450 100644 --- a/apps/chat/api/chat_authentication_api.py +++ b/apps/chat/api/chat_authentication_api.py @@ -6,14 +6,19 @@ @date:2025/6/6 19:59 @desc: """ -from chat.serializers.chat_authentication import AuthenticationSerializer + +from django.utils.translation import gettext_lazy as _ +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter + +from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer from common.mixins.api_mixin import APIMixin class ChatAuthenticationAPI(APIMixin): @staticmethod def get_request(): - return AuthenticationSerializer() + return AnonymousAuthenticationSerializer @staticmethod def get_parameters(): @@ -22,3 +27,35 @@ class ChatAuthenticationAPI(APIMixin): @staticmethod def get_response(): pass + + +class ChatAuthenticationProfileAPI(APIMixin): + + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="access_token", + description=_("access_token"), + type=OpenApiTypes.STR, + location='query', + required=True, + )] + + +class ChatOpenAPI(APIMixin): + @staticmethod + def get_parameters(): + return [OpenApiParameter( + name="workspace_id", + description="工作空间id", + type=OpenApiTypes.STR, + location='path', + required=True, + ), + OpenApiParameter( + name="application_id", + description="应用id", + type=OpenApiTypes.STR, + location='path', + required=True, + )] diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py new file mode 100644 index 000000000..54bc79aa1 --- /dev/null +++ b/apps/chat/serializers/chat.py @@ -0,0 +1,346 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: chat.py + @date:2025/6/9 11:23 + @desc: +""" + +from gettext import gettext +from typing import List + +import uuid_utils.compat as uuid +from django.db.models import QuerySet +from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers + +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler +from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep +from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \ + BaseGenerateHumanMessageStep +from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep +from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep +from application.flow.common import Answer +from application.flow.i_step_node import WorkFlowPostHandler +from application.flow.workflow_manage import WorkflowManage, Flow +from application.models import Application, ApplicationTypeChoices, WorkFlowVersion, ApplicationKnowledgeMapping, \ + ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat +from application.serializers.common import ChatInfo +from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed, ChatException +from common.handle.base_to_response import BaseToResponse +from common.handle.impl.response.system_to_response import SystemToResponse +from common.utils.common import flat_map +from knowledge.models import Document, Paragraph +from models_provider.models import Model, Status + + +class ChatMessageSerializers(serializers.Serializer): + message = serializers.CharField(required=True, label=_("User Questions")) + stream = serializers.BooleanField(required=True, + label=_("Is the answer in streaming mode")) + re_chat = serializers.BooleanField(required=True, label=_("Do you want to reply again")) + chat_record_id = serializers.UUIDField(required=False, allow_null=True, + label=_("Conversation record id")) + + node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + label=_("Node id")) + + runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, + label=_("Runtime node id")) + + node_data = serializers.DictField(required=False, allow_null=True, + label=_("Node parameters")) + + form_data = serializers.DictField(required=False, label=_("Global variables")) + image_list = serializers.ListField(required=False, label=_("picture")) + document_list = serializers.ListField(required=False, label=_("document")) + audio_list = serializers.ListField(required=False, label=_("Audio")) + other_list = serializers.ListField(required=False, label=_("Other")) + child_node = serializers.DictField(required=False, allow_null=True, + label=_("Child Nodes")) + + +def get_post_handler(chat_info: ChatInfo): + class PostHandler(PostResponseHandler): + + def handler(self, + chat_id, + chat_record_id, + paragraph_list: List[Paragraph], + problem_text: str, + answer_text, + manage: PipelineManage, + step: BaseChatStep, + padding_problem_text: str = None, + **kwargs): + answer_list = [[Answer(answer_text, 'ai-chat-node', 'ai-chat-node', 'ai-chat-node', {}, 'ai-chat-node', + kwargs.get('reasoning_content', '')).to_dict()]] + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + problem_text=problem_text, + answer_text=answer_text, + details=manage.get_details(), + message_tokens=manage.context['message_tokens'], + answer_tokens=manage.context['answer_tokens'], + answer_text_list=answer_list, + run_time=manage.context['run_time'], + index=len(chat_info.chat_record_list) + 1) + chat_info.append_chat_record(chat_record) + # 重新设置缓存 + chat_info.set_cache() + + return PostHandler() + + +class ChatSerializers(serializers.Serializer): + chat_id = serializers.UUIDField(required=True, label=_("Conversation ID")) + chat_user_id = serializers.CharField(required=True, label=_("Client id")) + chat_user_type = serializers.CharField(required=True, label=_("Client Type")) + application_id = serializers.UUIDField(required=True, allow_null=True, + label=_("Application ID")) + + def is_valid_application_workflow(self, *, raise_exception=False): + self.is_valid_intraday_access_num() + + def is_valid_chat_id(self, chat_info: ChatInfo): + if self.data.get('application_id') is not None and self.data.get('application_id') != str( + chat_info.application.id): + raise ChatException(500, _("Conversation does not exist")) + + def is_valid_intraday_access_num(self): + if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__( + self.data.get('chat_user_type')): + access_client = QuerySet(ApplicationChatUserStats).filter(chat_user_id=self.data.get('chat_user_id'), + application_id=self.data.get( + 'application_id')).first() + if access_client is None: + access_client = ApplicationChatUserStats(chat_user_id=self.data.get('chat_user_id'), + chat_user_type=self.data.get('chat_user_type'), + application_id=self.data.get('application_id'), + access_num=0, + intraday_access_num=0) + access_client.save() + + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + if application_access_token.access_num <= access_client.intraday_access_num: + raise AppChatNumOutOfBoundsFailed(1002, _("The number of visits exceeds today's visits")) + + def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False): + self.is_valid_intraday_access_num() + model = chat_info.application.model + if model is None: + return chat_info + model = QuerySet(Model).filter(id=model.id).first() + if model is None: + return chat_info + if model.status == Status.ERROR: + raise ChatException(500, _("The current model is not available")) + if model.status == Status.DOWNLOAD: + raise ChatException(500, _("The model is downloading, please try again later")) + return chat_info + + def chat_simple(self, chat_info: ChatInfo, instance, base_to_response): + message = instance.get('message') + re_chat = instance.get('re_chat') + stream = instance.get('stream') + chat_user_id = self.data.get('chat_user_id') + chat_user_type = self.data.get('chat_user_type') + form_data = instance.get("form_data") + pipeline_manage_builder = PipelineManage.builder() + # 如果开启了问题优化,则添加上问题优化步骤 + if chat_info.application.problem_optimization: + pipeline_manage_builder.append_step(BaseResetProblemStep) + # 构建流水线管理器 + pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep) + .append_step(BaseGenerateHumanMessageStep) + .append_step(BaseChatStep) + .add_base_to_response(base_to_response) + .build()) + exclude_paragraph_id_list = [] + # 相同问题是否需要排除已经查询到的段落 + if re_chat: + paragraph_id_list = flat_map( + [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for + chat_record in chat_info.chat_record_list if + chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in + chat_record.details['search_step']]) + exclude_paragraph_id_list = list(set(paragraph_id_list)) + # 构建运行参数 + params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list, + chat_user_id, chat_user_type, stream, form_data) + # 运行流水线作业 + pipeline_message.run(params) + return pipeline_message.context['chat_result'] + + @staticmethod + def get_chat_record(chat_info, chat_record_id): + if chat_info is not None: + chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if + str(chat_record.id) == str(chat_record_id)] + if chat_record_list is not None and len(chat_record_list): + return chat_record_list[-1] + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first() + if chat_record is None: + raise ChatException(500, _("Conversation record does not exist")) + chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first() + return chat_record + + def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response): + message = self.data.get('message') + re_chat = self.data.get('re_chat') + stream = self.data.get('stream') + chat_user_id = instance.get('chat_user_id') + chat_user_type = instance.get('chat_user_type') + form_data = self.data.get('form_data') + image_list = self.data.get('image_list') + document_list = self.data.get('document_list') + audio_list = self.data.get('audio_list') + other_list = self.data.get('other_list') + user_id = chat_info.application.user_id + chat_record_id = self.data.get('chat_record_id') + chat_record = None + history_chat_record = chat_info.chat_record_list + if chat_record_id is not None: + chat_record = self.get_chat_record(chat_info, chat_record_id) + history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id] + work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow), + {'history_chat_record': history_chat_record, 'question': message, + 'chat_id': chat_info.chat_id, 'chat_record_id': str( + uuid.uuid1()) if chat_record is None else chat_record.id, + 'stream': stream, + 're_chat': re_chat, + 'chat_user_id': chat_user_id, + 'chat_user_type': chat_user_type, + 'user_id': user_id}, + WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type), + base_to_response, form_data, image_list, document_list, audio_list, + other_list, + self.data.get('runtime_node_id'), + self.data.get('node_data'), chat_record, self.data.get('child_node')) + r = work_flow_manage.run() + return r + + def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()): + super().is_valid(raise_exception=True) + ChatMessageSerializers(data=instance).is_valid(raise_exception=True) + chat_info = self.get_chat_info() + self.is_valid_chat_id(chat_info) + if chat_info.application.type == ApplicationTypeChoices.SIMPLE: + self.is_valid_application_simple(raise_exception=True, chat_info=chat_info), + return self.chat_simple(chat_info, instance, base_to_response) + else: + self.is_valid_application_workflow(raise_exception=True) + return self.chat_work_flow(chat_info, instance, base_to_response) + + def get_chat_info(self): + self.is_valid(raise_exception=True) + chat_id = self.data.get('chat_id') + chat_info: ChatInfo = ChatInfo.get_cache(chat_id) + if chat_info is None: + chat_info: ChatInfo = self.re_open_chat(chat_id) + chat_info.set_cache() + return chat_info + + def re_open_chat(self, chat_id: str): + chat = QuerySet(Chat).filter(id=chat_id).first() + if chat is None: + raise ChatException(500, _("Conversation does not exist")) + application = QuerySet(Application).filter(id=chat.application_id).first() + if application is None: + raise ChatException(500, _("Application does not exist")) + if application.type == ApplicationTypeChoices.SIMPLE: + return self.re_open_chat_simple(chat_id, application) + else: + return self.re_open_chat_work_flow(chat_id, application) + + def re_open_chat_simple(self, chat_id, application): + # 数据集id列表 + knowledge_id_list = [str(row.dataset_id) for row in + QuerySet(ApplicationKnowledgeMapping).filter( + application_id=application.id)] + + # 需要排除的文档 + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + knowledge_id__in=knowledge_id_list, + is_active=False)] + chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), knowledge_id_list, + exclude_document_id_list, application) + chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5]) + chat_record_list.sort(key=lambda r: r.create_time) + for chat_record in chat_record_list: + chat_info.chat_record_list.append(chat_record) + return chat_info + + def re_open_chat_work_flow(self, chat_id, application): + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by( + '-create_time')[0:1].first() + if work_flow_version is None: + raise ChatException(500, _("The application has not been published. Please use it after publishing.")) + + chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), [], [], + application, work_flow_version) + chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5]) + chat_record_list.sort(key=lambda r: r.create_time) + for chat_record in chat_record_list: + chat_info.chat_record_list.append(chat_record) + return chat_info + + +class OpenChatSerializers(serializers.Serializer): + workspace_id = serializers.CharField(required=True) + application_id = serializers.UUIDField(required=True) + chat_user_id = serializers.CharField(required=True, label=_("Client id")) + chat_user_type = serializers.CharField(required=True, label=_("Client Type")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + workspace_id = self.data.get('workspace_id') + application_id = self.data.get('application_id') + if not QuerySet(Application).filter(id=application_id, workspace_id=workspace_id).exists(): + raise AppApiException(500, gettext('Application does not exist')) + + def open(self): + self.is_valid(raise_exception=True) + application_id = self.data.get('application_id') + application = QuerySet(Application).get(id=application_id) + if application.type == ApplicationTypeChoices.SIMPLE: + return self.open_simple(application) + else: + return self.open_work_flow(application) + + def open_work_flow(self, application): + self.is_valid(raise_exception=True) + application_id = self.data.get('application_id') + chat_user_id = self.data.get("chat_user_id") + chat_user_type = self.data.get("chat_user_type") + chat_id = str(uuid.uuid7()) + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by( + '-create_time')[0:1].first() + if work_flow_version is None: + raise AppApiException(500, + gettext( + "The application has not been published. Please use it after publishing.")) + ChatInfo(chat_id, chat_user_id, chat_user_type, [], + [], + application, work_flow_version).set_cache() + return chat_id + + def open_simple(self, application): + application_id = self.data.get('application_id') + chat_user_id = self.data.get("chat_user_id") + chat_user_type = self.data.get("chat_user_type") + knowledge_id_list = [str(row.dataset_id) for row in + QuerySet(ApplicationKnowledgeMapping).filter( + application_id=application_id)] + chat_id = str(uuid.uuid7()) + ChatInfo(chat_id, chat_user_id, chat_user_type, knowledge_id_list, + [str(document.id) for document in + QuerySet(Document).filter( + knowledge_id__in=knowledge_id_list, + is_active=False)], + application).set_cache() + return chat_id diff --git a/apps/chat/serializers/chat_authentication.py b/apps/chat/serializers/chat_authentication.py index 92dfab92c..bbcd57846 100644 --- a/apps/chat/serializers/chat_authentication.py +++ b/apps/chat/serializers/chat_authentication.py @@ -14,43 +14,18 @@ from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ from rest_framework import serializers -from application.models import ApplicationAccessToken, ClientType, Application, ApplicationTypeChoices, WorkFlowVersion +from application.models import ApplicationAccessToken, ChatUserType, Application, ApplicationTypeChoices, \ + WorkFlowVersion from application.serializers.application import ApplicationSerializerModel from common.auth.common import ChatUserToken, ChatAuthentication - from common.constants.authentication_type import AuthenticationType from common.constants.cache_version import Cache_Version from common.database_model_manage.database_model_manage import DatabaseModelManage -from common.exception.app_exception import NotFound404, AppApiException, AppUnauthorizedFailed +from common.exception.app_exception import NotFound404, AppUnauthorizedFailed -def auth(application_id, access_token, authentication_value, token_details): - client_id = token_details.get('client_id') - if client_id is None: - client_id = str(uuid.uuid1()) - _type = AuthenticationType.CHAT_ANONYMOUS_USER - if authentication_value is not None: - application_setting_model = DatabaseModelManage.get_model('application_setting') - if application_setting_model is not None: - application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first() - if application_setting.authentication: - auth_type = application_setting.authentication_value.get('type') - auth_value = authentication_value.get(auth_type + '_value') - if auth_type == 'password': - if authentication_value.get('type') == 'password': - if auth_value == authentication_value.get(auth_type + '_value'): - return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER, - client_id, ChatAuthentication(auth_type, True, True)) - else: - raise AppApiException(500, '认证方式不匹配') - return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER, - client_id, ChatAuthentication(None, False, False)) - - -class AuthenticationSerializer(serializers.Serializer): +class AnonymousAuthenticationSerializer(serializers.Serializer): access_token = serializers.CharField(required=True, label=_("access_token")) - authentication_value = serializers.JSONField(required=False, allow_null=True, - label=_("Certification Information")) def auth(self, request, with_valid=True): token = request.META.get('HTTP_AUTHORIZATION') @@ -65,16 +40,42 @@ class AuthenticationSerializer(serializers.Serializer): self.is_valid(raise_exception=True) access_token = self.data.get("access_token") application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() - authentication_value = self.data.get('authentication_value', None) if application_access_token is not None and application_access_token.is_active: - chat_user_token = auth(application_access_token.application_id, access_token, authentication_value, - token_details) - - return chat_user_token.to_token() + chat_user_id = token_details.get('chat_user_id') or str(uuid.uuid1()) + _type = AuthenticationType.CHAT_ANONYMOUS_USER + return ChatUserToken(application_access_token.application_id, None, access_token, _type, + ChatUserType.ANONYMOUS_USER, + chat_user_id, ChatAuthentication(None, False, False)).to_token() else: raise NotFound404(404, _("Invalid access_token")) +class AuthProfileSerializer(serializers.Serializer): + access_token = serializers.CharField(required=True, label=_("access_token")) + + def profile(self): + self.is_valid(raise_exception=True) + access_token = self.data.get("access_token") + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + application_id = application_access_token.application_id + profile = { + 'authentication': False + } + application_setting_model = DatabaseModelManage.get_model('application_setting') + if application_setting_model: + application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first() + profile = { + 'icon': application_setting.application.icon, + 'application_name': application_setting.application.name, + 'bg_icon': application_setting.chat_background, + 'authentication': application_setting.authentication, + 'authentication_type': application_setting.authentication_value.get( + 'type', 'password'), + 'login_value': application_setting.authentication_value.get('login_value', []) + } + return profile + + class ApplicationProfileSerializer(serializers.Serializer): application_id = serializers.UUIDField(required=True, label=_("Application ID")) @@ -119,11 +120,6 @@ class ApplicationProfileSerializer(serializers.Serializer): 'avatar': application_setting.avatar, 'show_avatar': application_setting.show_avatar, 'float_icon': application_setting.float_icon, - 'authentication': application_setting.authentication, - 'authentication_type': application_setting.authentication_value.get( - 'type', 'password'), - 'login_value': application_setting.authentication_value.get( - 'login_value', []), 'disclaimer': application_setting.disclaimer, 'disclaimer_value': application_setting.disclaimer_value, 'custom_theme': application_setting.custom_theme, diff --git a/apps/chat/urls.py b/apps/chat/urls.py index 0f0bc12a2..8550572b6 100644 --- a/apps/chat/urls.py +++ b/apps/chat/urls.py @@ -6,6 +6,9 @@ app_name = 'chat' urlpatterns = [ path('chat/embed', views.ChatEmbedView.as_view()), - path('application/authentication', views.Authentication.as_view()), - path('profile', views.ApplicationProfile.as_view()) + path('application/anonymous_authentication', views.AnonymousAuthentication.as_view()), + path('auth/profile', views.AuthProfile.as_view()), + path('profile', views.ApplicationProfile.as_view()), + path('chat_message/', views.ChatView.as_view()), + path('workspace//application//open', views.OpenView.as_view()) ] diff --git a/apps/chat/views/chat.py b/apps/chat/views/chat.py index bda65408e..ed2dd71ff 100644 --- a/apps/chat/views/chat.py +++ b/apps/chat/views/chat.py @@ -12,14 +12,18 @@ from drf_spectacular.utils import extend_schema from rest_framework.request import Request from rest_framework.views import APIView -from chat.api.chat_authentication_api import ChatAuthenticationAPI -from chat.serializers.chat_authentication import AuthenticationSerializer, ApplicationProfileSerializer +from chat.api.chat_api import ChatAPI +from chat.api.chat_authentication_api import ChatAuthenticationAPI, ChatAuthenticationProfileAPI, ChatOpenAPI +from chat.serializers.chat import OpenChatSerializers, ChatSerializers +from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer, ApplicationProfileSerializer, \ + AuthProfileSerializer from common.auth import TokenAuth +from common.constants.permission_constants import ChatAuth from common.exception.app_exception import AppAuthenticationFailed from common.result import result -class Authentication(APIView): +class AnonymousAuthentication(APIView): def options(self, request, *args, **kwargs): return HttpResponse( headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", @@ -28,18 +32,16 @@ class Authentication(APIView): @extend_schema( methods=['POST'], - description=_('Application Certification'), - summary=_('Application Certification'), - operation_id=_('Application Certification'), # type: ignore + description=_('Application Anonymous Certification'), + summary=_('Application Anonymous Certification'), + operation_id=_('Application Anonymous Certification'), # type: ignore request=ChatAuthenticationAPI.get_request(), responses=None, tags=[_('Chat')] # type: ignore ) def post(self, request: Request): return result.success( - AuthenticationSerializer(data={'access_token': request.data.get("access_token"), - 'authentication_value': request.data.get( - 'authentication_value')}).auth( + AnonymousAuthenticationSerializer(data={'access_token': request.data.get("access_token")}).auth( request), headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", "Access-Control-Allow-Methods": "POST", @@ -60,11 +62,61 @@ class ApplicationProfile(APIView): tags=[_('Chat')] # type: ignore ) def get(self, request: Request): - if 'application_id' in request.auth.keywords: + if isinstance(request.auth, ChatAuth): return result.success(ApplicationProfileSerializer( - data={'application_id': request.auth.keywords.get('application_id')}).profile()) + data={'application_id': request.auth.application_id}).profile()) raise AppAuthenticationFailed(401, "身份异常") +class AuthProfile(APIView): + @extend_schema( + methods=['GET'], + description=_("Get application authentication information"), + summary=_("Get application authentication information"), + operation_id=_("Get application authentication information"), # type: ignore + parameters=ChatAuthenticationProfileAPI.get_parameters(), + responses=None, + tags=[_('Chat')] # type: ignore + ) + def get(self, request: Request): + return result.success( + AuthProfileSerializer(data={'access_token': request.query_params.get("access_token")}).profile()) + + class ChatView(APIView): - pass + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['POST'], + description=_("dialogue"), + summary=_("dialogue"), + operation_id=_("dialogue"), # type: ignore + request=ChatAPI.get_request(), + parameters=ChatAPI.get_parameters(), + responses=None, + tags=[_('Chat')] # type: ignore + ) + def post(self, request: Request, chat_id: str): + return ChatSerializers(data={'chat_id': chat_id, + 'chat_user_id': request.auth.chat_user_id, + 'chat_user_type': request.auth.chat_user_type, + 'application_id': request.auth.application_id} + ).chat(request.data) + + +class OpenView(APIView): + authentication_classes = [TokenAuth] + + @extend_schema( + methods=['GET'], + description=_("Get the session id according to the application id"), + summary=_("Get the session id according to the application id"), + operation_id=_("Get the session id according to the application id"), # type: ignore + parameters=ChatOpenAPI.get_parameters(), + responses=None, + tags=[_('Chat')] # type: ignore + ) + def get(self, request: Request, workspace_id: str, application_id: str): + return result.success(OpenChatSerializers( + data={'workspace_id': workspace_id, 'application_id': application_id, + 'chat_user_id': request.auth.chat_user_id, 'chat_user_type': request.auth.chat_user_type}).open()) diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py index 700ed12b4..3b37f28e2 100644 --- a/apps/common/auth/common.py +++ b/apps/common/auth/common.py @@ -32,14 +32,14 @@ class ChatAuthentication: class ChatUserToken: - def __init__(self, application_id, user_id, access_token, _type, client_type, client_id, + def __init__(self, application_id, user_id, access_token, _type, chat_user_type, chat_user_id, authentication: ChatAuthentication): self.application_id = application_id self.user_id = user_id, self.access_token = access_token self.type = _type - self.client_type = client_type - self.client_id = client_id + self.chat_user_type = chat_user_type + self.chat_user_id = chat_user_id self.authentication = authentication def to_dict(self): @@ -48,8 +48,8 @@ class ChatUserToken: 'user_id': str(self.user_id), 'access_token': self.access_token, 'type': str(self.type.value), - 'client_type': str(self.client_type), - 'client_id': str(self.client_id), + 'chat_user_type': str(self.chat_user_type), + 'chat_user_id': str(self.chat_user_id), 'authentication': self.authentication.to_string() } @@ -59,6 +59,6 @@ class ChatUserToken: @staticmethod def new_instance(token_dict): return ChatUserToken(token_dict.get('application_id'), token_dict.get('user_id'), - token_dict.get('access_token'), token_dict.get('type'), token_dict.get('client_type'), - token_dict.get('client_id'), + token_dict.get('access_token'), token_dict.get('type'), token_dict.get('chat_user_type'), + token_dict.get('chat_user_id'), ChatAuthentication.new_instance(token_dict.get('authentication'))) diff --git a/apps/common/auth/handle/impl/chat_anonymous_user_token.py b/apps/common/auth/handle/impl/chat_anonymous_user_token.py index 0d8da19ce..2e00314f1 100644 --- a/apps/common/auth/handle/impl/chat_anonymous_user_token.py +++ b/apps/common/auth/handle/impl/chat_anonymous_user_token.py @@ -9,12 +9,11 @@ from django.db.models import QuerySet from django.utils.translation import gettext_lazy as _ -from application.models import ApplicationAccessToken, ClientType +from application.models import ApplicationAccessToken from common.auth.common import ChatUserToken - from common.auth.handle.auth_base_handle import AuthBaseHandle from common.constants.authentication_type import AuthenticationType -from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth +from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, ChatAuth from common.exception.app_exception import AppAuthenticationFailed, ChatException @@ -45,11 +44,11 @@ class ChatAnonymousUserToken(AuthBaseHandle): if request.path != '/api/application/profile': if chat_user_token.authentication.is_auth and not chat_user_token.authentication.auth_passed: raise ChatException(1002, _('Authentication information is incorrect')) - return None, Auth( + return None, ChatAuth( current_role_list=[RoleConstants.CHAT_ANONYMOUS_USER], permission_list=[ Permission(group=Group.APPLICATION, operate=Operate.USE)], application_id=application_access_token.application_id, - client_id=auth_details.get('client_id'), - client_type=ClientType.ANONYMOUS_USER) + chat_user_id=chat_user_token.chat_user_id, + chat_user_type=chat_user_token.chat_user_type) diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py index 6b7b74ad8..b30c22414 100644 --- a/apps/common/constants/cache_version.py +++ b/apps/common/constants/cache_version.py @@ -27,6 +27,9 @@ class Cache_Version(Enum): # 应用对接三方应用的缓存 APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key + # 对话 + CHAT = "CHAT", lambda key: key + def get_version(self): return self.value[0] diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py index 2617472ce..bdf81942c 100644 --- a/apps/common/constants/permission_constants.py +++ b/apps/common/constants/permission_constants.py @@ -925,6 +925,22 @@ def get_permission_list_by_resource_group(resource_group: ResourcePermissionGrou PermissionConstants[k].value.resource_permission_group_list.__contains__(resource_group)] +class ChatAuth: + def __init__(self, + current_role_list: List[RoleConstants | Role], + permission_list: List[PermissionConstants | Permission], + chat_user_id, + chat_user_type, + application_id): + # 权限列表 + self.permission_list = permission_list + # 角色列表 + self.role_list = current_role_list + self.chat_user_id = chat_user_id + self.chat_user_type = chat_user_type + self.application_id = application_id + + class Auth: """ 用于存储当前用户的角色和权限