diff --git a/apps/application/chat_pipeline/I_base_chat_pipeline.py b/apps/application/chat_pipeline/I_base_chat_pipeline.py
index a35bdc39c..8ef4896f3 100644
--- a/apps/application/chat_pipeline/I_base_chat_pipeline.py
+++ b/apps/application/chat_pipeline/I_base_chat_pipeline.py
@@ -12,17 +12,17 @@ from typing import Type
from rest_framework import serializers
-from dataset.models import Paragraph
+from knowledge.models import Paragraph
class ParagraphPipelineModel:
- def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
+ def __init__(self, _id: str, document_id: str, knowledge_id: str, content: str, title: str, status: str,
is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
hit_handling_method: str, directly_return_similarity: float, meta: dict = None):
self.id = _id
self.document_id = document_id
- self.dataset_id = dataset_id
+ self.knowledge_id = knowledge_id
self.content = content
self.title = title
self.status = status,
@@ -39,7 +39,7 @@ class ParagraphPipelineModel:
return {
'id': self.id,
'document_id': self.document_id,
- 'dataset_id': self.dataset_id,
+ 'knowledge_id': self.knowledge_id,
'content': self.content,
'title': self.title,
'status': self.status,
@@ -66,7 +66,7 @@ class ParagraphPipelineModel:
if isinstance(paragraph, Paragraph):
self.paragraph = {'id': paragraph.id,
'document_id': paragraph.document_id,
- 'dataset_id': paragraph.dataset_id,
+ 'knowledge_id': paragraph.knowledge_id,
'content': paragraph.content,
'title': paragraph.title,
'status': paragraph.status,
@@ -106,7 +106,7 @@ class ParagraphPipelineModel:
def build(self):
return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
- str(self.paragraph.get('dataset_id')),
+ str(self.paragraph.get('knowledge_id')),
self.paragraph.get('content'), self.paragraph.get('title'),
self.paragraph.get('status'),
self.paragraph.get('is_active'),
diff --git a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
index 4abda510d..1451cd88e 100644
--- a/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/i_chat_step.py
@@ -44,7 +44,7 @@ class PostResponseHandler:
@abstractmethod
def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
answer_text,
- manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
+ manage, step, padding_problem_text: str = None, **kwargs):
pass
@@ -68,8 +68,9 @@ class IChatStep(IBaseChatPipelineStep):
label=_("Completion Question"))
# 是否使用流的形式输出
stream = serializers.BooleanField(required=False, label=_("Streaming Output"))
- client_id = serializers.CharField(required=True, label=_("Client id"))
- client_type = serializers.CharField(required=True, label=_("Client Type"))
+ chat_user_id = serializers.CharField(required=True, label=_("Chat user id"))
+
+ chat_user_type = serializers.CharField(required=True, label=_("Chat user Type"))
# 未查询到引用分段
no_references_setting = NoReferencesSetting(required=True,
label=_("No reference segment settings"))
@@ -104,6 +105,6 @@ class IChatStep(IBaseChatPipelineStep):
user_id: str = None,
paragraph_list=None,
manage: PipelineManage = None,
- padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
+ padding_problem_text: str = None, stream: bool = True, chat_user_id=None, chat_user_type=None,
no_references_setting=None, model_params_setting=None, model_setting=None, **kwargs):
pass
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index 4e5142ac2..13cab54e2 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -25,15 +25,16 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.flow.tools import Reasoning
-from application.models.application_api_key import ApplicationPublicAccessClient
-from common.constants.authentication_type import AuthenticationType
+from application.models import ApplicationChatUserStats, ChatUserType
from models_provider.tools import get_model_instance_by_model_user_id
-def add_access_num(client_id=None, client_type=None, application_id=None):
- if client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value and application_id is not None:
- application_public_access_client = (QuerySet(ApplicationPublicAccessClient).filter(client_id=client_id,
- application_id=application_id)
+def add_access_num(chat_user_id=None, chat_user_type=None, application_id=None):
+ if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
+ chat_user_type) and application_id is not None:
+ application_public_access_client = (QuerySet(ApplicationChatUserStats).filter(chat_user_id=chat_user_id,
+ chat_user_type=chat_user_type,
+ application_id=application_id)
.first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
@@ -124,11 +125,9 @@ def event_content(response,
request_token = 0
response_token = 0
write_context(step, manage, request_token, response_token, all_text)
- asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
- all_text, manage, step, padding_problem_text, client_id,
- reasoning_content=reasoning_content if reasoning_content_enable else ''
- , asker=asker)
+ all_text, manage, step, padding_problem_text,
+ reasoning_content=reasoning_content if reasoning_content_enable else '')
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], '', True,
request_token, response_token,
@@ -139,10 +138,8 @@ def event_content(response,
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
all_text = 'Exception:' + str(e)
write_context(step, manage, 0, 0, all_text)
- asker = manage.context.get('form_data', {}).get('asker', None)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
- all_text, manage, step, padding_problem_text, client_id, reasoning_content='',
- asker=asker)
+ all_text, manage, step, padding_problem_text, reasoning_content='')
add_access_num(client_id, client_type, manage.context.get('application_id'))
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), 'ai-chat-node',
[], all_text,
@@ -165,7 +162,7 @@ class BaseChatStep(IChatStep):
manage: PipelineManage = None,
padding_problem_text: str = None,
stream: bool = True,
- client_id=None, client_type=None,
+ chat_user_id=None, chat_user_type=None,
no_references_setting=None,
model_params_setting=None,
model_setting=None,
@@ -175,12 +172,13 @@ class BaseChatStep(IChatStep):
if stream:
return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
- manage, padding_problem_text, client_id, client_type, no_references_setting,
+ manage, padding_problem_text, chat_user_id, chat_user_type,
+ no_references_setting,
model_setting)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
- manage, padding_problem_text, client_id, client_type, no_references_setting,
+ manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
model_setting)
def get_details(self, manage, **kwargs):
@@ -235,7 +233,7 @@ class BaseChatStep(IChatStep):
paragraph_list=None,
manage: PipelineManage = None,
padding_problem_text: str = None,
- client_id=None, client_type=None,
+ chat_user_id=None, chat_user_type=None,
no_references_setting=None,
model_setting=None):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
@@ -244,7 +242,8 @@ class BaseChatStep(IChatStep):
r = StreamingHttpResponse(
streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
post_response_handler, manage, self, chat_model, message_list, problem_text,
- padding_problem_text, client_id, client_type, is_ai_chat, model_setting),
+ padding_problem_text, chat_user_id, chat_user_type, is_ai_chat,
+ model_setting),
content_type='text/event-stream;charset=utf-8')
r['Cache-Control'] = 'no-cache'
diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
index 68cfbbcb9..edb4d9a72 100644
--- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
+++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py
@@ -15,7 +15,7 @@ from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineMode
from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \
IGenerateHumanMessageStep
from application.models import ChatRecord
-from common.util.split_model import flat_map
+from common.utils.common import flat_map
class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep):
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
index f086e72be..bb08ca9e2 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py
@@ -26,8 +26,8 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
padding_problem_text = serializers.CharField(required=False,
label=_("System completes question text"))
# 需要查询的数据集id列表
- dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
- label=_("Dataset id list"))
+ knowledge_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
+ label=_("Dataset id list"))
# 需要排除的文档id
exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True),
label=_("List of document ids to exclude"))
@@ -55,7 +55,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
self.context['paragraph_list'] = paragraph_list
@abstractmethod
- def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
@@ -65,7 +65,7 @@ class ISearchDatasetStep(IBaseChatPipelineStep):
:param similarity: 相关性
:param top_n: 查询多少条
:param problem_text: 用户问题
- :param dataset_id_list: 需要查询的数据集id列表
+ :param knowledge_id_list: 需要查询的数据集id列表
:param exclude_document_id_list: 需要排除的文档id
:param exclude_paragraph_id_list: 需要排除段落id
:param padding_problem_text 补全问题
diff --git a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
index b7f069785..aacd36ed8 100644
--- a/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
+++ b/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
@@ -35,42 +35,33 @@ def get_model_by_id(_id, user_id):
return model
-def get_embedding_id(dataset_id_list):
-<<<<<<< Updated upstream:apps/chat/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
- dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
- if len(set([dataset.embedding_model_id for dataset in dataset_list])) > 1:
- raise Exception(_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
- if len(dataset_list) == 0:
- raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
- return dataset_list[0].embedding_model_id
-=======
- knowledge_list = QuerySet(Knowledge).filter(id__in=dataset_id_list)
+def get_embedding_id(knowledge_id_list):
+ knowledge_list = QuerySet(Knowledge).filter(id__in=knowledge_id_list)
if len(set([knowledge.embedding_mode_id for knowledge in knowledge_list])) > 1:
raise Exception(
_("The vector model of the associated knowledge base is inconsistent and the segmentation cannot be recalled."))
if len(knowledge_list) == 0:
raise Exception(_("The knowledge base setting is wrong, please reset the knowledge base"))
return knowledge_list[0].embedding_mode_id
->>>>>>> Stashed changes:apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py
class BaseSearchDatasetStep(ISearchDatasetStep):
- def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
+ def execute(self, problem_text: str, knowledge_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
user_id=None,
**kwargs) -> List[ParagraphPipelineModel]:
- if len(dataset_id_list) == 0:
+ if len(knowledge_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
- model_id = get_embedding_id(dataset_id_list)
+ model_id = get_embedding_id(knowledge_id_list)
model = get_model_by_id(model_id, user_id)
self.context['model_name'] = model.name
embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
- embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
+ embedding_list = vector.query(exec_problem_text, embedding_value, knowledge_id_list, exclude_document_id_list,
exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode))
if embedding_list is None:
return []
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index 701d5d78d..2fd61408a 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -18,8 +18,8 @@ from rest_framework import serializers
from rest_framework.exceptions import ValidationError, ErrorDetail
from application.flow.common import Answer, NodeChunk
-from application.models import ChatRecord
-from application.models import ApplicationChatClientStats
+from application.models import ChatRecord, ChatUserType
+from application.models import ApplicationChatUserStats
from common.constants.authentication_type import AuthenticationType
from common.field.common import InstanceField
@@ -45,10 +45,10 @@ def is_interrupt(node, step_variable: Dict, global_variable: Dict):
class WorkFlowPostHandler:
- def __init__(self, chat_info, client_id, client_type):
+ def __init__(self, chat_info, chat_user_id, chat_user_type):
self.chat_info = chat_info
- self.client_id = client_id
- self.client_type = client_type
+ self.chat_user_id = chat_user_id
+ self.chat_user_type = chat_user_type
def handler(self, chat_id,
chat_record_id,
@@ -84,13 +84,13 @@ class WorkFlowPostHandler:
run_time=time.time() - workflow.context['start_time'],
index=0)
asker = workflow.context.get('asker', None)
- self.chat_info.append_chat_record(chat_record, self.client_id, asker)
- # 重新设置缓存
- chat_cache.set(chat_id,
- self.chat_info, timeout=60 * 30)
- if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
- application_public_access_client = (QuerySet(ApplicationChatClientStats)
- .filter(client_id=self.client_id,
+ self.chat_info.append_chat_record(chat_record)
+ self.chat_info.set_cahce()
+ if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
+ self.chat_user_type):
+ application_public_access_client = (QuerySet(ApplicationChatUserStats)
+ .filter(chat_user_id=self.chat_user_id,
+ chat_user_type=self.chat_user_type,
application_id=self.chat_info.application.id).first())
if application_public_access_client is not None:
application_public_access_client.access_num = application_public_access_client.access_num + 1
diff --git a/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py b/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py
new file mode 100644
index 000000000..c9346dcbe
--- /dev/null
+++ b/apps/application/migrations/0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more.py
@@ -0,0 +1,56 @@
+# Generated by Django 5.2 on 2025-06-09 05:55
+
+import uuid
+import uuid_utils.compat
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0003_applicationaccesstoken_show_exec_chat_client_type_and_more'),
+ ]
+
+ operations = [
+ migrations.RemoveIndex(
+ model_name='applicationchatclientstats',
+ name='application_applica_f89647_idx',
+ ),
+ migrations.RenameField(
+ model_name='chat',
+ old_name='client_id',
+ new_name='chat_user_id',
+ ),
+ migrations.RenameField(
+ model_name='chat',
+ old_name='client_type',
+ new_name='chat_user_type',
+ ),
+ migrations.RemoveField(
+ model_name='applicationchatclientstats',
+ name='client_id',
+ ),
+ migrations.RemoveField(
+ model_name='applicationchatclientstats',
+ name='client_type',
+ ),
+ migrations.AddField(
+ model_name='applicationchatclientstats',
+ name='chat_user_id',
+ field=models.UUIDField(default=uuid_utils.compat.uuid7, verbose_name='对话用户id'),
+ ),
+ migrations.AddField(
+ model_name='applicationchatclientstats',
+ name='chat_user_type',
+ field=models.CharField(choices=[('ANONYMOUS_USER', '匿名用户'), ('CHAT_USER', '对话用户'), ('SYSTEM_API_KEY', '系统API_KEY'), ('APPLICATION_API_KEY', '应用API_KEY')], default='ANONYMOUS_USER', max_length=64, verbose_name='对话用户类型'),
+ ),
+ migrations.AlterField(
+ model_name='chat',
+ name='id',
+ field=models.UUIDField(default=uuid.UUID('01975341-b4e8-7d52-913b-1bb67d7d8107'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'),
+ ),
+ migrations.AddIndex(
+ model_name='applicationchatclientstats',
+ index=models.Index(fields=['application_id', 'chat_user_id'], name='application_applica_23b4d2_idx'),
+ ),
+ ]
diff --git a/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py b/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py
new file mode 100644
index 000000000..9962ee48b
--- /dev/null
+++ b/apps/application/migrations/0005_rename_applicationchatclientstats_applicationchatuserstats_and_more.py
@@ -0,0 +1,32 @@
+# Generated by Django 5.2 on 2025-06-09 07:31
+
+import uuid
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+
+ dependencies = [
+ ('application', '0004_remove_applicationchatclientstats_application_applica_f89647_idx_and_more'),
+ ]
+
+ operations = [
+ migrations.RenameModel(
+ old_name='ApplicationChatClientStats',
+ new_name='ApplicationChatUserStats',
+ ),
+ migrations.RenameIndex(
+ model_name='applicationchatuserstats',
+ new_name='application_applica_1652ba_idx',
+ old_name='application_applica_23b4d2_idx',
+ ),
+ migrations.AlterField(
+ model_name='chat',
+ name='id',
+ field=models.UUIDField(default=uuid.UUID('01975399-efa5-7dc3-8f97-edc67332ed24'), editable=False, primary_key=True, serialize=False, verbose_name='主键id'),
+ ),
+ migrations.AlterModelTable(
+ name='applicationchatuserstats',
+ table='application_chat_user_stats',
+ ),
+ ]
diff --git a/apps/application/models/application_chat.py b/apps/application/models/application_chat.py
index 325551ebb..8c2a22633 100644
--- a/apps/application/models/application_chat.py
+++ b/apps/application/models/application_chat.py
@@ -17,7 +17,7 @@ from common.encoder.encoder import SystemEncoder
from common.mixins.app_model_mixin import AppModelMixin
-class ClientType(models.TextChoices):
+class ChatUserType(models.TextChoices):
ANONYMOUS_USER = "ANONYMOUS_USER", '匿名用户'
CHAT_USER = "CHAT_USER", "对话用户"
SYSTEM_API_KEY = "SYSTEM_API_KEY", "系统API_KEY"
@@ -28,9 +28,9 @@ class Chat(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7(), editable=False, verbose_name="主键id")
application = models.ForeignKey(Application, on_delete=models.CASCADE)
abstract = models.CharField(max_length=1024, verbose_name="摘要")
- client_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
- client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices,
- default=ClientType.ANONYMOUS_USER)
+ chat_user_id = models.UUIDField(verbose_name="客户端id", default=None, null=True)
+ chat_user_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ChatUserType.choices,
+ default=ChatUserType.ANONYMOUS_USER)
is_deleted = models.BooleanField(verbose_name="逻辑删除", default=False)
class Meta:
@@ -86,17 +86,17 @@ class ChatRecord(AppModelMixin):
db_table = "application_chat_record"
-class ApplicationChatClientStats(AppModelMixin):
+class ApplicationChatUserStats(AppModelMixin):
id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid7, editable=False, verbose_name="主键id")
- client_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="公共访问链接客户端id")
- client_type = models.CharField(max_length=64, verbose_name="客户端类型", choices=ClientType.choices,
- default=ClientType.ANONYMOUS_USER)
+ chat_user_id = models.UUIDField(max_length=128, default=uuid.uuid7, verbose_name="对话用户id")
+ chat_user_type = models.CharField(max_length=64, verbose_name="对话用户类型", choices=ChatUserType.choices,
+ default=ChatUserType.ANONYMOUS_USER)
application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id")
access_num = models.IntegerField(default=0, verbose_name="访问总次数次数")
intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数")
class Meta:
- db_table = "application_chat_client_stats"
+ db_table = "application_chat_user_stats"
indexes = [
- models.Index(fields=['application_id', 'client_id']),
+ models.Index(fields=['application_id', 'chat_user_id']),
]
diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py
new file mode 100644
index 000000000..8af009077
--- /dev/null
+++ b/apps/application/serializers/common.py
@@ -0,0 +1,137 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎虎
+ @file: common.py
+ @date:2025/6/9 13:42
+ @desc:
+"""
+from datetime import datetime
+from typing import List
+
+from django.core.cache import cache
+from django.db.models import QuerySet
+from django.utils.translation import gettext_lazy as _
+
+from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
+from application.models import Application, WorkFlowVersion, ChatRecord, Chat
+from common.constants.cache_version import Cache_Version
+from models_provider.models import Model
+from models_provider.tools import get_model_credential
+
+
+class ChatInfo:
+ def __init__(self,
+ chat_id: str,
+ chat_user_id: str,
+ chat_user_type: str,
+ knowledge_id_list: List[str],
+ exclude_document_id_list: list[str],
+ application: Application,
+ work_flow_version: WorkFlowVersion = None):
+ """
+ :param chat_id: 对话id
+ :param chat_user_id 对话用户id
+ :param chat_user_type 对话用户类型
+ :param knowledge_id_list: 知识库列表
+ :param exclude_document_id_list: 排除的文档
+ :param application: 应用信息
+ """
+ self.chat_id = chat_id
+ self.chat_user_id = chat_user_id
+ self.chat_user_type = chat_user_type
+ self.application = application
+ self.knowledge_id_list = knowledge_id_list
+ self.exclude_document_id_list = exclude_document_id_list
+ self.chat_record_list: List[ChatRecord] = []
+ self.work_flow_version = work_flow_version
+
+ @staticmethod
+ def get_no_references_setting(knowledge_setting, model_setting):
+ no_references_setting = knowledge_setting.get(
+ 'no_references_setting', {
+ 'status': 'ai_questioning',
+ 'value': '{question}'})
+ if no_references_setting.get('status') == 'ai_questioning':
+ no_references_prompt = model_setting.get('no_references_prompt', '{question}')
+ no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
+ return no_references_setting
+
+ def to_base_pipeline_manage_params(self):
+ knowledge_setting = self.application.knowledge_setting
+ model_setting = self.application.model_setting
+ model_id = self.application.model.id if self.application.model is not None else None
+ model_params_setting = None
+ if model_id is not None:
+ model = QuerySet(Model).filter(id=model_id).first()
+ credential = get_model_credential(model.provider, model.model_type, model.model_name)
+ model_params_setting = credential.get_model_params_setting_form(model.model_name).get_default_form_data()
+ return {
+ 'knowledge_id_list': self.knowledge_id_list,
+ 'exclude_document_id_list': self.exclude_document_id_list,
+ 'exclude_paragraph_id_list': [],
+ 'top_n': knowledge_setting.get('top_n') or 3,
+ 'similarity': knowledge_setting.get('similarity') or 0.6,
+ 'max_paragraph_char_number': knowledge_setting.get('max_paragraph_char_number') or 5000,
+ 'history_chat_record': self.chat_record_list,
+ 'chat_id': self.chat_id,
+ 'dialogue_number': self.application.dialogue_number,
+ 'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len(
+ self.application.problem_optimization_prompt) > 0 else _(
+ "() contains the user's question. Answer the guessed user's question based on the context ({question}) Requirement: Output a complete question and put it in the tag"),
+ 'prompt': model_setting.get(
+ 'prompt') if 'prompt' in model_setting and len(model_setting.get(
+ 'prompt')) > 0 else Application.get_default_model_prompt(),
+ 'system': model_setting.get(
+ 'system', None),
+ 'model_id': model_id,
+ 'problem_optimization': self.application.problem_optimization,
+ 'stream': True,
+ 'model_setting': model_setting,
+ 'model_params_setting': model_params_setting if self.application.model_params_setting is None or len(
+ self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting,
+ 'search_mode': self.application.knowledge_setting.get('search_mode') or 'embedding',
+ 'no_references_setting': self.get_no_references_setting(self.application.knowledge_setting, model_setting),
+ 'user_id': self.application.user_id,
+ 'application_id': self.application.id
+ }
+
+ def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
+ exclude_paragraph_id_list, chat_user_id: str, chat_user_type, stream=True,
+ form_data=None):
+ if form_data is None:
+ form_data = {}
+ params = self.to_base_pipeline_manage_params()
+ return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
+ 'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'chat_user_id': chat_user_id,
+ 'chat_user_type': chat_user_type, 'form_data': form_data}
+
+ def append_chat_record(self, chat_record: ChatRecord):
+ chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
+ chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.problem_text is not None else ""
+ is_save = True
+ # 存入缓存中
+ for index in range(len(self.chat_record_list)):
+ record = self.chat_record_list[index]
+ if record.id == chat_record.id:
+ self.chat_record_list[index] = chat_record
+ is_save = False
+ if is_save:
+ self.chat_record_list.append(chat_record)
+ cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
+ timeout=60 * 30)
+ if self.application.id is not None:
+ Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
+ chat_user_id=self.chat_user_id, chat_user_type=self.chat_user_type).save()
+ else:
+ QuerySet(Chat).filter(id=self.chat_id).update(update_time=datetime.now())
+ # 插入会话记录
+ chat_record.save()
+
+ def set_cache(self):
+ cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
+ timeout=60 * 30)
+
+ @staticmethod
+ def get_cache(chat_id):
+ return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
diff --git a/apps/chat/api/chat_api.py b/apps/chat/api/chat_api.py
new file mode 100644
index 000000000..69b45da32
--- /dev/null
+++ b/apps/chat/api/chat_api.py
@@ -0,0 +1,29 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎虎
+ @file: chat_api.py
+ @date:2025/6/9 15:23
+ @desc:
+"""
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import OpenApiParameter
+
+from chat.serializers.chat import ChatMessageSerializers
+from common.mixins.api_mixin import APIMixin
+
+
+class ChatAPI(APIMixin):
+ @staticmethod
+ def get_parameters():
+ return [OpenApiParameter(
+ name="chat_id",
+ description="对话id",
+ type=OpenApiTypes.STR,
+ location='path',
+ required=True,
+ )]
+
+ @staticmethod
+ def get_request():
+ return ChatMessageSerializers
diff --git a/apps/chat/api/chat_authentication_api.py b/apps/chat/api/chat_authentication_api.py
index cae9006a5..89fbd1450 100644
--- a/apps/chat/api/chat_authentication_api.py
+++ b/apps/chat/api/chat_authentication_api.py
@@ -6,14 +6,19 @@
@date:2025/6/6 19:59
@desc:
"""
-from chat.serializers.chat_authentication import AuthenticationSerializer
+
+from django.utils.translation import gettext_lazy as _
+from drf_spectacular.types import OpenApiTypes
+from drf_spectacular.utils import OpenApiParameter
+
+from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer
from common.mixins.api_mixin import APIMixin
class ChatAuthenticationAPI(APIMixin):
@staticmethod
def get_request():
- return AuthenticationSerializer()
+ return AnonymousAuthenticationSerializer
@staticmethod
def get_parameters():
@@ -22,3 +27,35 @@ class ChatAuthenticationAPI(APIMixin):
@staticmethod
def get_response():
pass
+
+
+class ChatAuthenticationProfileAPI(APIMixin):
+
+ @staticmethod
+ def get_parameters():
+ return [OpenApiParameter(
+ name="access_token",
+ description=_("access_token"),
+ type=OpenApiTypes.STR,
+ location='query',
+ required=True,
+ )]
+
+
+class ChatOpenAPI(APIMixin):
+ @staticmethod
+ def get_parameters():
+ return [OpenApiParameter(
+ name="workspace_id",
+ description="工作空间id",
+ type=OpenApiTypes.STR,
+ location='path',
+ required=True,
+ ),
+ OpenApiParameter(
+ name="application_id",
+ description="应用id",
+ type=OpenApiTypes.STR,
+ location='path',
+ required=True,
+ )]
diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py
new file mode 100644
index 000000000..54bc79aa1
--- /dev/null
+++ b/apps/chat/serializers/chat.py
@@ -0,0 +1,346 @@
+# coding=utf-8
+"""
+ @project: MaxKB
+ @Author:虎虎
+ @file: chat.py
+ @date:2025/6/9 11:23
+ @desc:
+"""
+
+from gettext import gettext
+from typing import List
+
+import uuid_utils.compat as uuid
+from django.db.models import QuerySet
+from django.utils.translation import gettext_lazy as _
+from rest_framework import serializers
+
+from application.chat_pipeline.pipeline_manage import PipelineManage
+from application.chat_pipeline.step.chat_step.i_chat_step import PostResponseHandler
+from application.chat_pipeline.step.chat_step.impl.base_chat_step import BaseChatStep
+from application.chat_pipeline.step.generate_human_message_step.impl.base_generate_human_message_step import \
+ BaseGenerateHumanMessageStep
+from application.chat_pipeline.step.reset_problem_step.impl.base_reset_problem_step import BaseResetProblemStep
+from application.chat_pipeline.step.search_dataset_step.impl.base_search_dataset_step import BaseSearchDatasetStep
+from application.flow.common import Answer
+from application.flow.i_step_node import WorkFlowPostHandler
+from application.flow.workflow_manage import WorkflowManage, Flow
+from application.models import Application, ApplicationTypeChoices, WorkFlowVersion, ApplicationKnowledgeMapping, \
+ ChatUserType, ApplicationChatUserStats, ApplicationAccessToken, ChatRecord, Chat
+from application.serializers.common import ChatInfo
+from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed, ChatException
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
+from common.utils.common import flat_map
+from knowledge.models import Document, Paragraph
+from models_provider.models import Model, Status
+
+
+class ChatMessageSerializers(serializers.Serializer):
+ message = serializers.CharField(required=True, label=_("User Questions"))
+ stream = serializers.BooleanField(required=True,
+ label=_("Is the answer in streaming mode"))
+ re_chat = serializers.BooleanField(required=True, label=_("Do you want to reply again"))
+ chat_record_id = serializers.UUIDField(required=False, allow_null=True,
+ label=_("Conversation record id"))
+
+ node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+ label=_("Node id"))
+
+ runtime_node_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
+ label=_("Runtime node id"))
+
+ node_data = serializers.DictField(required=False, allow_null=True,
+ label=_("Node parameters"))
+
+ form_data = serializers.DictField(required=False, label=_("Global variables"))
+ image_list = serializers.ListField(required=False, label=_("picture"))
+ document_list = serializers.ListField(required=False, label=_("document"))
+ audio_list = serializers.ListField(required=False, label=_("Audio"))
+ other_list = serializers.ListField(required=False, label=_("Other"))
+ child_node = serializers.DictField(required=False, allow_null=True,
+ label=_("Child Nodes"))
+
+
+def get_post_handler(chat_info: ChatInfo):
+ class PostHandler(PostResponseHandler):
+
+ def handler(self,
+ chat_id,
+ chat_record_id,
+ paragraph_list: List[Paragraph],
+ problem_text: str,
+ answer_text,
+ manage: PipelineManage,
+ step: BaseChatStep,
+ padding_problem_text: str = None,
+ **kwargs):
+ answer_list = [[Answer(answer_text, 'ai-chat-node', 'ai-chat-node', 'ai-chat-node', {}, 'ai-chat-node',
+ kwargs.get('reasoning_content', '')).to_dict()]]
+ chat_record = ChatRecord(id=chat_record_id,
+ chat_id=chat_id,
+ problem_text=problem_text,
+ answer_text=answer_text,
+ details=manage.get_details(),
+ message_tokens=manage.context['message_tokens'],
+ answer_tokens=manage.context['answer_tokens'],
+ answer_text_list=answer_list,
+ run_time=manage.context['run_time'],
+ index=len(chat_info.chat_record_list) + 1)
+ chat_info.append_chat_record(chat_record)
+ # 重新设置缓存
+ chat_info.set_cache()
+
+ return PostHandler()
+
+
+class ChatSerializers(serializers.Serializer):
+ chat_id = serializers.UUIDField(required=True, label=_("Conversation ID"))
+ chat_user_id = serializers.CharField(required=True, label=_("Client id"))
+ chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
+ application_id = serializers.UUIDField(required=True, allow_null=True,
+ label=_("Application ID"))
+
+ def is_valid_application_workflow(self, *, raise_exception=False):
+ self.is_valid_intraday_access_num()
+
+ def is_valid_chat_id(self, chat_info: ChatInfo):
+ if self.data.get('application_id') is not None and self.data.get('application_id') != str(
+ chat_info.application.id):
+ raise ChatException(500, _("Conversation does not exist"))
+
+ def is_valid_intraday_access_num(self):
+ if [ChatUserType.ANONYMOUS_USER.value, ChatUserType.CHAT_USER.value].__contains__(
+ self.data.get('chat_user_type')):
+ access_client = QuerySet(ApplicationChatUserStats).filter(chat_user_id=self.data.get('chat_user_id'),
+ application_id=self.data.get(
+ 'application_id')).first()
+ if access_client is None:
+ access_client = ApplicationChatUserStats(chat_user_id=self.data.get('chat_user_id'),
+ chat_user_type=self.data.get('chat_user_type'),
+ application_id=self.data.get('application_id'),
+ access_num=0,
+ intraday_access_num=0)
+ access_client.save()
+
+ application_access_token = QuerySet(ApplicationAccessToken).filter(
+ application_id=self.data.get('application_id')).first()
+ if application_access_token.access_num <= access_client.intraday_access_num:
+ raise AppChatNumOutOfBoundsFailed(1002, _("The number of visits exceeds today's visits"))
+
+ def is_valid_application_simple(self, *, chat_info: ChatInfo, raise_exception=False):
+ self.is_valid_intraday_access_num()
+ model = chat_info.application.model
+ if model is None:
+ return chat_info
+ model = QuerySet(Model).filter(id=model.id).first()
+ if model is None:
+ return chat_info
+ if model.status == Status.ERROR:
+ raise ChatException(500, _("The current model is not available"))
+ if model.status == Status.DOWNLOAD:
+ raise ChatException(500, _("The model is downloading, please try again later"))
+ return chat_info
+
+ def chat_simple(self, chat_info: ChatInfo, instance, base_to_response):
+ message = instance.get('message')
+ re_chat = instance.get('re_chat')
+ stream = instance.get('stream')
+ chat_user_id = self.data.get('chat_user_id')
+ chat_user_type = self.data.get('chat_user_type')
+ form_data = instance.get("form_data")
+ pipeline_manage_builder = PipelineManage.builder()
+ # 如果开启了问题优化,则添加上问题优化步骤
+ if chat_info.application.problem_optimization:
+ pipeline_manage_builder.append_step(BaseResetProblemStep)
+ # 构建流水线管理器
+ pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
+ .append_step(BaseGenerateHumanMessageStep)
+ .append_step(BaseChatStep)
+ .add_base_to_response(base_to_response)
+ .build())
+ exclude_paragraph_id_list = []
+ # 相同问题是否需要排除已经查询到的段落
+ if re_chat:
+ paragraph_id_list = flat_map(
+ [[paragraph.get('id') for paragraph in chat_record.details['search_step']['paragraph_list']] for
+ chat_record in chat_info.chat_record_list if
+ chat_record.problem_text == message and 'search_step' in chat_record.details and 'paragraph_list' in
+ chat_record.details['search_step']])
+ exclude_paragraph_id_list = list(set(paragraph_id_list))
+ # 构建运行参数
+ params = chat_info.to_pipeline_manage_params(message, get_post_handler(chat_info), exclude_paragraph_id_list,
+ chat_user_id, chat_user_type, stream, form_data)
+ # 运行流水线作业
+ pipeline_message.run(params)
+ return pipeline_message.context['chat_result']
+
+ @staticmethod
+ def get_chat_record(chat_info, chat_record_id):
+ if chat_info is not None:
+ chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
+ str(chat_record.id) == str(chat_record_id)]
+ if chat_record_list is not None and len(chat_record_list):
+ return chat_record_list[-1]
+ chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_info.chat_id).first()
+ if chat_record is None:
+ raise ChatException(500, _("Conversation record does not exist"))
+ chat_record = QuerySet(ChatRecord).filter(id=chat_record_id).first()
+ return chat_record
+
+ def chat_work_flow(self, chat_info: ChatInfo, instance: dict, base_to_response):
+ message = self.data.get('message')
+ re_chat = self.data.get('re_chat')
+ stream = self.data.get('stream')
+ chat_user_id = instance.get('chat_user_id')
+ chat_user_type = instance.get('chat_user_type')
+ form_data = self.data.get('form_data')
+ image_list = self.data.get('image_list')
+ document_list = self.data.get('document_list')
+ audio_list = self.data.get('audio_list')
+ other_list = self.data.get('other_list')
+ user_id = chat_info.application.user_id
+ chat_record_id = self.data.get('chat_record_id')
+ chat_record = None
+ history_chat_record = chat_info.chat_record_list
+ if chat_record_id is not None:
+ chat_record = self.get_chat_record(chat_info, chat_record_id)
+ history_chat_record = [r for r in chat_info.chat_record_list if str(r.id) != chat_record_id]
+ work_flow_manage = WorkflowManage(Flow.new_instance(chat_info.work_flow_version.work_flow),
+ {'history_chat_record': history_chat_record, 'question': message,
+ 'chat_id': chat_info.chat_id, 'chat_record_id': str(
+ uuid.uuid1()) if chat_record is None else chat_record.id,
+ 'stream': stream,
+ 're_chat': re_chat,
+ 'chat_user_id': chat_user_id,
+ 'chat_user_type': chat_user_type,
+ 'user_id': user_id},
+ WorkFlowPostHandler(chat_info, chat_user_id, chat_user_type),
+ base_to_response, form_data, image_list, document_list, audio_list,
+ other_list,
+ self.data.get('runtime_node_id'),
+ self.data.get('node_data'), chat_record, self.data.get('child_node'))
+ r = work_flow_manage.run()
+ return r
+
+ def chat(self, instance: dict, base_to_response: BaseToResponse = SystemToResponse()):
+ super().is_valid(raise_exception=True)
+ ChatMessageSerializers(data=instance).is_valid(raise_exception=True)
+ chat_info = self.get_chat_info()
+ self.is_valid_chat_id(chat_info)
+ if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
+ self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
+ return self.chat_simple(chat_info, instance, base_to_response)
+ else:
+ self.is_valid_application_workflow(raise_exception=True)
+ return self.chat_work_flow(chat_info, instance, base_to_response)
+
+ def get_chat_info(self):
+ self.is_valid(raise_exception=True)
+ chat_id = self.data.get('chat_id')
+ chat_info: ChatInfo = ChatInfo.get_cache(chat_id)
+ if chat_info is None:
+ chat_info: ChatInfo = self.re_open_chat(chat_id)
+ chat_info.set_cache()
+ return chat_info
+
+ def re_open_chat(self, chat_id: str):
+ chat = QuerySet(Chat).filter(id=chat_id).first()
+ if chat is None:
+ raise ChatException(500, _("Conversation does not exist"))
+ application = QuerySet(Application).filter(id=chat.application_id).first()
+ if application is None:
+ raise ChatException(500, _("Application does not exist"))
+ if application.type == ApplicationTypeChoices.SIMPLE:
+ return self.re_open_chat_simple(chat_id, application)
+ else:
+ return self.re_open_chat_work_flow(chat_id, application)
+
+ def re_open_chat_simple(self, chat_id, application):
+ # 数据集id列表
+ knowledge_id_list = [str(row.dataset_id) for row in
+ QuerySet(ApplicationKnowledgeMapping).filter(
+ application_id=application.id)]
+
+ # 需要排除的文档
+ exclude_document_id_list = [str(document.id) for document in
+ QuerySet(Document).filter(
+ knowledge_id__in=knowledge_id_list,
+ is_active=False)]
+ chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), knowledge_id_list,
+ exclude_document_id_list, application)
+ chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
+ chat_record_list.sort(key=lambda r: r.create_time)
+ for chat_record in chat_record_list:
+ chat_info.chat_record_list.append(chat_record)
+ return chat_info
+
+ def re_open_chat_work_flow(self, chat_id, application):
+ work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
+ '-create_time')[0:1].first()
+ if work_flow_version is None:
+ raise ChatException(500, _("The application has not been published. Please use it after publishing."))
+
+ chat_info = ChatInfo(chat_id, self.data.get('chat_user_id'), self.data.get('chat_user_type'), [], [],
+ application, work_flow_version)
+ chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
+ chat_record_list.sort(key=lambda r: r.create_time)
+ for chat_record in chat_record_list:
+ chat_info.chat_record_list.append(chat_record)
+ return chat_info
+
+
+class OpenChatSerializers(serializers.Serializer):
+ workspace_id = serializers.CharField(required=True)
+ application_id = serializers.UUIDField(required=True)
+ chat_user_id = serializers.CharField(required=True, label=_("Client id"))
+ chat_user_type = serializers.CharField(required=True, label=_("Client Type"))
+
+ def is_valid(self, *, raise_exception=False):
+ super().is_valid(raise_exception=True)
+ workspace_id = self.data.get('workspace_id')
+ application_id = self.data.get('application_id')
+ if not QuerySet(Application).filter(id=application_id, workspace_id=workspace_id).exists():
+ raise AppApiException(500, gettext('Application does not exist'))
+
+ def open(self):
+ self.is_valid(raise_exception=True)
+ application_id = self.data.get('application_id')
+ application = QuerySet(Application).get(id=application_id)
+ if application.type == ApplicationTypeChoices.SIMPLE:
+ return self.open_simple(application)
+ else:
+ return self.open_work_flow(application)
+
+ def open_work_flow(self, application):
+ self.is_valid(raise_exception=True)
+ application_id = self.data.get('application_id')
+ chat_user_id = self.data.get("chat_user_id")
+ chat_user_type = self.data.get("chat_user_type")
+ chat_id = str(uuid.uuid7())
+ work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application_id).order_by(
+ '-create_time')[0:1].first()
+ if work_flow_version is None:
+ raise AppApiException(500,
+ gettext(
+ "The application has not been published. Please use it after publishing."))
+ ChatInfo(chat_id, chat_user_id, chat_user_type, [],
+ [],
+ application, work_flow_version).set_cache()
+ return chat_id
+
+ def open_simple(self, application):
+ application_id = self.data.get('application_id')
+ chat_user_id = self.data.get("chat_user_id")
+ chat_user_type = self.data.get("chat_user_type")
+ knowledge_id_list = [str(row.dataset_id) for row in
+ QuerySet(ApplicationKnowledgeMapping).filter(
+ application_id=application_id)]
+ chat_id = str(uuid.uuid7())
+ ChatInfo(chat_id, chat_user_id, chat_user_type, knowledge_id_list,
+ [str(document.id) for document in
+ QuerySet(Document).filter(
+ knowledge_id__in=knowledge_id_list,
+ is_active=False)],
+ application).set_cache()
+ return chat_id
diff --git a/apps/chat/serializers/chat_authentication.py b/apps/chat/serializers/chat_authentication.py
index 92dfab92c..bbcd57846 100644
--- a/apps/chat/serializers/chat_authentication.py
+++ b/apps/chat/serializers/chat_authentication.py
@@ -14,43 +14,18 @@ from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from rest_framework import serializers
-from application.models import ApplicationAccessToken, ClientType, Application, ApplicationTypeChoices, WorkFlowVersion
+from application.models import ApplicationAccessToken, ChatUserType, Application, ApplicationTypeChoices, \
+ WorkFlowVersion
from application.serializers.application import ApplicationSerializerModel
from common.auth.common import ChatUserToken, ChatAuthentication
-
from common.constants.authentication_type import AuthenticationType
from common.constants.cache_version import Cache_Version
from common.database_model_manage.database_model_manage import DatabaseModelManage
-from common.exception.app_exception import NotFound404, AppApiException, AppUnauthorizedFailed
+from common.exception.app_exception import NotFound404, AppUnauthorizedFailed
-def auth(application_id, access_token, authentication_value, token_details):
- client_id = token_details.get('client_id')
- if client_id is None:
- client_id = str(uuid.uuid1())
- _type = AuthenticationType.CHAT_ANONYMOUS_USER
- if authentication_value is not None:
- application_setting_model = DatabaseModelManage.get_model('application_setting')
- if application_setting_model is not None:
- application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first()
- if application_setting.authentication:
- auth_type = application_setting.authentication_value.get('type')
- auth_value = authentication_value.get(auth_type + '_value')
- if auth_type == 'password':
- if authentication_value.get('type') == 'password':
- if auth_value == authentication_value.get(auth_type + '_value'):
- return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER,
- client_id, ChatAuthentication(auth_type, True, True))
- else:
- raise AppApiException(500, '认证方式不匹配')
- return ChatUserToken(application_id, None, access_token, _type, ClientType.ANONYMOUS_USER,
- client_id, ChatAuthentication(None, False, False))
-
-
-class AuthenticationSerializer(serializers.Serializer):
+class AnonymousAuthenticationSerializer(serializers.Serializer):
access_token = serializers.CharField(required=True, label=_("access_token"))
- authentication_value = serializers.JSONField(required=False, allow_null=True,
- label=_("Certification Information"))
def auth(self, request, with_valid=True):
token = request.META.get('HTTP_AUTHORIZATION')
@@ -65,16 +40,42 @@ class AuthenticationSerializer(serializers.Serializer):
self.is_valid(raise_exception=True)
access_token = self.data.get("access_token")
application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
- authentication_value = self.data.get('authentication_value', None)
if application_access_token is not None and application_access_token.is_active:
- chat_user_token = auth(application_access_token.application_id, access_token, authentication_value,
- token_details)
-
- return chat_user_token.to_token()
+ chat_user_id = token_details.get('chat_user_id') or str(uuid.uuid1())
+ _type = AuthenticationType.CHAT_ANONYMOUS_USER
+ return ChatUserToken(application_access_token.application_id, None, access_token, _type,
+ ChatUserType.ANONYMOUS_USER,
+ chat_user_id, ChatAuthentication(None, False, False)).to_token()
else:
raise NotFound404(404, _("Invalid access_token"))
+class AuthProfileSerializer(serializers.Serializer):
+ access_token = serializers.CharField(required=True, label=_("access_token"))
+
+ def profile(self):
+ self.is_valid(raise_exception=True)
+ access_token = self.data.get("access_token")
+ application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
+ application_id = application_access_token.application_id
+ profile = {
+ 'authentication': False
+ }
+ application_setting_model = DatabaseModelManage.get_model('application_setting')
+ if application_setting_model:
+ application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first()
+ profile = {
+ 'icon': application_setting.application.icon,
+ 'application_name': application_setting.application.name,
+ 'bg_icon': application_setting.chat_background,
+ 'authentication': application_setting.authentication,
+ 'authentication_type': application_setting.authentication_value.get(
+ 'type', 'password'),
+ 'login_value': application_setting.authentication_value.get('login_value', [])
+ }
+ return profile
+
+
class ApplicationProfileSerializer(serializers.Serializer):
application_id = serializers.UUIDField(required=True, label=_("Application ID"))
@@ -119,11 +120,6 @@ class ApplicationProfileSerializer(serializers.Serializer):
'avatar': application_setting.avatar,
'show_avatar': application_setting.show_avatar,
'float_icon': application_setting.float_icon,
- 'authentication': application_setting.authentication,
- 'authentication_type': application_setting.authentication_value.get(
- 'type', 'password'),
- 'login_value': application_setting.authentication_value.get(
- 'login_value', []),
'disclaimer': application_setting.disclaimer,
'disclaimer_value': application_setting.disclaimer_value,
'custom_theme': application_setting.custom_theme,
diff --git a/apps/chat/urls.py b/apps/chat/urls.py
index 0f0bc12a2..8550572b6 100644
--- a/apps/chat/urls.py
+++ b/apps/chat/urls.py
@@ -6,6 +6,9 @@ app_name = 'chat'
urlpatterns = [
path('chat/embed', views.ChatEmbedView.as_view()),
- path('application/authentication', views.Authentication.as_view()),
- path('profile', views.ApplicationProfile.as_view())
+ path('application/anonymous_authentication', views.AnonymousAuthentication.as_view()),
+ path('auth/profile', views.AuthProfile.as_view()),
+ path('profile', views.ApplicationProfile.as_view()),
+ path('chat_message/', views.ChatView.as_view()),
+ path('workspace//application//open', views.OpenView.as_view())
]
diff --git a/apps/chat/views/chat.py b/apps/chat/views/chat.py
index bda65408e..ed2dd71ff 100644
--- a/apps/chat/views/chat.py
+++ b/apps/chat/views/chat.py
@@ -12,14 +12,18 @@ from drf_spectacular.utils import extend_schema
from rest_framework.request import Request
from rest_framework.views import APIView
-from chat.api.chat_authentication_api import ChatAuthenticationAPI
-from chat.serializers.chat_authentication import AuthenticationSerializer, ApplicationProfileSerializer
+from chat.api.chat_api import ChatAPI
+from chat.api.chat_authentication_api import ChatAuthenticationAPI, ChatAuthenticationProfileAPI, ChatOpenAPI
+from chat.serializers.chat import OpenChatSerializers, ChatSerializers
+from chat.serializers.chat_authentication import AnonymousAuthenticationSerializer, ApplicationProfileSerializer, \
+ AuthProfileSerializer
from common.auth import TokenAuth
+from common.constants.permission_constants import ChatAuth
from common.exception.app_exception import AppAuthenticationFailed
from common.result import result
-class Authentication(APIView):
+class AnonymousAuthentication(APIView):
def options(self, request, *args, **kwargs):
return HttpResponse(
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
@@ -28,18 +32,16 @@ class Authentication(APIView):
@extend_schema(
methods=['POST'],
- description=_('Application Certification'),
- summary=_('Application Certification'),
- operation_id=_('Application Certification'), # type: ignore
+ description=_('Application Anonymous Certification'),
+ summary=_('Application Anonymous Certification'),
+ operation_id=_('Application Anonymous Certification'), # type: ignore
request=ChatAuthenticationAPI.get_request(),
responses=None,
tags=[_('Chat')] # type: ignore
)
def post(self, request: Request):
return result.success(
- AuthenticationSerializer(data={'access_token': request.data.get("access_token"),
- 'authentication_value': request.data.get(
- 'authentication_value')}).auth(
+ AnonymousAuthenticationSerializer(data={'access_token': request.data.get("access_token")}).auth(
request),
headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true",
"Access-Control-Allow-Methods": "POST",
@@ -60,11 +62,61 @@ class ApplicationProfile(APIView):
tags=[_('Chat')] # type: ignore
)
def get(self, request: Request):
- if 'application_id' in request.auth.keywords:
+ if isinstance(request.auth, ChatAuth):
return result.success(ApplicationProfileSerializer(
- data={'application_id': request.auth.keywords.get('application_id')}).profile())
+ data={'application_id': request.auth.application_id}).profile())
raise AppAuthenticationFailed(401, "身份异常")
+class AuthProfile(APIView):
+ @extend_schema(
+ methods=['GET'],
+ description=_("Get application authentication information"),
+ summary=_("Get application authentication information"),
+ operation_id=_("Get application authentication information"), # type: ignore
+ parameters=ChatAuthenticationProfileAPI.get_parameters(),
+ responses=None,
+ tags=[_('Chat')] # type: ignore
+ )
+ def get(self, request: Request):
+ return result.success(
+ AuthProfileSerializer(data={'access_token': request.query_params.get("access_token")}).profile())
+
+
class ChatView(APIView):
- pass
+ authentication_classes = [TokenAuth]
+
+ @extend_schema(
+ methods=['POST'],
+ description=_("dialogue"),
+ summary=_("dialogue"),
+ operation_id=_("dialogue"), # type: ignore
+ request=ChatAPI.get_request(),
+ parameters=ChatAPI.get_parameters(),
+ responses=None,
+ tags=[_('Chat')] # type: ignore
+ )
+ def post(self, request: Request, chat_id: str):
+ return ChatSerializers(data={'chat_id': chat_id,
+ 'chat_user_id': request.auth.chat_user_id,
+ 'chat_user_type': request.auth.chat_user_type,
+ 'application_id': request.auth.application_id}
+ ).chat(request.data)
+
+
+class OpenView(APIView):
+ authentication_classes = [TokenAuth]
+
+ @extend_schema(
+ methods=['GET'],
+ description=_("Get the session id according to the application id"),
+ summary=_("Get the session id according to the application id"),
+ operation_id=_("Get the session id according to the application id"), # type: ignore
+ parameters=ChatOpenAPI.get_parameters(),
+ responses=None,
+ tags=[_('Chat')] # type: ignore
+ )
+ def get(self, request: Request, workspace_id: str, application_id: str):
+ return result.success(OpenChatSerializers(
+ data={'workspace_id': workspace_id, 'application_id': application_id,
+ 'chat_user_id': request.auth.chat_user_id, 'chat_user_type': request.auth.chat_user_type}).open())
diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py
index 700ed12b4..3b37f28e2 100644
--- a/apps/common/auth/common.py
+++ b/apps/common/auth/common.py
@@ -32,14 +32,14 @@ class ChatAuthentication:
class ChatUserToken:
- def __init__(self, application_id, user_id, access_token, _type, client_type, client_id,
+ def __init__(self, application_id, user_id, access_token, _type, chat_user_type, chat_user_id,
authentication: ChatAuthentication):
self.application_id = application_id
self.user_id = user_id,
self.access_token = access_token
self.type = _type
- self.client_type = client_type
- self.client_id = client_id
+ self.chat_user_type = chat_user_type
+ self.chat_user_id = chat_user_id
self.authentication = authentication
def to_dict(self):
@@ -48,8 +48,8 @@ class ChatUserToken:
'user_id': str(self.user_id),
'access_token': self.access_token,
'type': str(self.type.value),
- 'client_type': str(self.client_type),
- 'client_id': str(self.client_id),
+ 'chat_user_type': str(self.chat_user_type),
+ 'chat_user_id': str(self.chat_user_id),
'authentication': self.authentication.to_string()
}
@@ -59,6 +59,6 @@ class ChatUserToken:
@staticmethod
def new_instance(token_dict):
return ChatUserToken(token_dict.get('application_id'), token_dict.get('user_id'),
- token_dict.get('access_token'), token_dict.get('type'), token_dict.get('client_type'),
- token_dict.get('client_id'),
+ token_dict.get('access_token'), token_dict.get('type'), token_dict.get('chat_user_type'),
+ token_dict.get('chat_user_id'),
ChatAuthentication.new_instance(token_dict.get('authentication')))
diff --git a/apps/common/auth/handle/impl/chat_anonymous_user_token.py b/apps/common/auth/handle/impl/chat_anonymous_user_token.py
index 0d8da19ce..2e00314f1 100644
--- a/apps/common/auth/handle/impl/chat_anonymous_user_token.py
+++ b/apps/common/auth/handle/impl/chat_anonymous_user_token.py
@@ -9,12 +9,11 @@
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
-from application.models import ApplicationAccessToken, ClientType
+from application.models import ApplicationAccessToken
from common.auth.common import ChatUserToken
-
from common.auth.handle.auth_base_handle import AuthBaseHandle
from common.constants.authentication_type import AuthenticationType
-from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, Auth
+from common.constants.permission_constants import RoleConstants, Permission, Group, Operate, ChatAuth
from common.exception.app_exception import AppAuthenticationFailed, ChatException
@@ -45,11 +44,11 @@ class ChatAnonymousUserToken(AuthBaseHandle):
if request.path != '/api/application/profile':
if chat_user_token.authentication.is_auth and not chat_user_token.authentication.auth_passed:
raise ChatException(1002, _('Authentication information is incorrect'))
- return None, Auth(
+ return None, ChatAuth(
current_role_list=[RoleConstants.CHAT_ANONYMOUS_USER],
permission_list=[
Permission(group=Group.APPLICATION,
operate=Operate.USE)],
application_id=application_access_token.application_id,
- client_id=auth_details.get('client_id'),
- client_type=ClientType.ANONYMOUS_USER)
+ chat_user_id=chat_user_token.chat_user_id,
+ chat_user_type=chat_user_token.chat_user_type)
diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py
index 6b7b74ad8..b30c22414 100644
--- a/apps/common/constants/cache_version.py
+++ b/apps/common/constants/cache_version.py
@@ -27,6 +27,9 @@ class Cache_Version(Enum):
# 应用对接三方应用的缓存
APPLICATION_THIRD_PARTY = "APPLICATION:THIRD_PARTY", lambda key: key
+ # 对话
+ CHAT = "CHAT", lambda key: key
+
def get_version(self):
return self.value[0]
diff --git a/apps/common/constants/permission_constants.py b/apps/common/constants/permission_constants.py
index 2617472ce..bdf81942c 100644
--- a/apps/common/constants/permission_constants.py
+++ b/apps/common/constants/permission_constants.py
@@ -925,6 +925,22 @@ def get_permission_list_by_resource_group(resource_group: ResourcePermissionGrou
PermissionConstants[k].value.resource_permission_group_list.__contains__(resource_group)]
+class ChatAuth:
+ def __init__(self,
+ current_role_list: List[RoleConstants | Role],
+ permission_list: List[PermissionConstants | Permission],
+ chat_user_id,
+ chat_user_type,
+ application_id):
+ # 权限列表
+ self.permission_list = permission_list
+ # 角色列表
+ self.role_list = current_role_list
+ self.chat_user_id = chat_user_id
+ self.chat_user_type = chat_user_type
+ self.application_id = application_id
+
+
class Auth:
"""
用于存储当前用户的角色和权限