diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py index f5c50bc54..0919abbe2 100644 --- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py +++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py @@ -154,6 +154,7 @@ class BaseChatStep(IChatStep): chat_result = iter(directly_return_chunk_list) else: chat_result = chat_model.stream(message_list) + is_ai_chat = True else: chat_result = chat_model.stream(message_list) is_ai_chat = True @@ -187,8 +188,18 @@ class BaseChatStep(IChatStep): 'status') == 'designated_answer': chat_result = AIMessage(content=no_references_setting.get('value')) else: - chat_result = chat_model.invoke(message_list) - is_ai_chat = True + if paragraph_list is not None and len(paragraph_list) > 0: + directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content) + for paragraph in paragraph_list if + paragraph.hit_handling_method == 'directly_return'] + if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0: + chat_result = iter(directly_return_chunk_list) + else: + chat_result = chat_model.invoke(message_list) + is_ai_chat = True + else: + chat_result = chat_model.invoke(message_list) + is_ai_chat = True chat_record_id = uuid.uuid1() if is_ai_chat: request_token = chat_model.get_num_tokens_from_messages(message_list) diff --git a/apps/application/migrations/0004_applicationaccesstoken_show_source.py b/apps/application/migrations/0004_applicationaccesstoken_show_source.py new file mode 100644 index 000000000..851d7315d --- /dev/null +++ b/apps/application/migrations/0004_applicationaccesstoken_show_source.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-25 11:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0003_application_icon'), + ] + + operations = [ + migrations.AddField( + model_name='applicationaccesstoken', + name='show_source', + field=models.BooleanField(default=False, verbose_name='是否显示知识来源'), + ), + ] diff --git a/apps/application/models/api_key_model.py b/apps/application/models/api_key_model.py index 1090d2e57..117ab0664 100644 --- a/apps/application/models/api_key_model.py +++ b/apps/application/models/api_key_model.py @@ -39,6 +39,7 @@ class ApplicationAccessToken(AppModelMixin): white_list = ArrayField(verbose_name="白名单列表", base_field=models.CharField(max_length=128, blank=True) , default=list) + show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源") class Meta: db_table = "application_access_token" diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index e60aaa320..3307d873a 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -28,7 +28,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel from common.constants.authentication_type import AuthenticationType from common.db.search import get_dynamics_model, native_search, native_page_search from common.db.sql_execute import select_list -from common.exception.app_exception import AppApiException, NotFound404 +from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed from common.field.common import UploadedImageField from common.util.field_message import ErrMessage from common.util.file_util import get_file_content @@ -170,7 +170,9 @@ class ApplicationSerializer(serializers.Serializer): white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True, error_messages=ErrMessage.char( "白名单")), - error_messages=ErrMessage.list("白名单列表")) + error_messages=ErrMessage.list("白名单列表")), + show_source = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean("是否显示知识来源")) def edit(self, instance: Dict, with_valid=True): if with_valid: @@ -190,6 +192,8 @@ class ApplicationSerializer(serializers.Serializer): application_access_token.white_active = instance.get("white_active") if 'white_list' in instance and instance.get('white_list') is not None: application_access_token.white_list = instance.get('white_list') + if 'show_source' in instance and instance.get('show_source') is not None: + application_access_token.show_source = instance.get('show_source') application_access_token.save() return self.one(with_valid=False) @@ -210,7 +214,8 @@ class ApplicationSerializer(serializers.Serializer): "is_active": application_access_token.is_active, 'access_num': application_access_token.access_num, 'white_active': application_access_token.white_active, - 'white_list': application_access_token.white_list + 'white_list': application_access_token.white_list, + 'show_source': application_access_token.show_source } class Authentication(serializers.Serializer): @@ -474,8 +479,12 @@ class ApplicationSerializer(serializers.Serializer): self.is_valid() application_id = self.data.get("application_id") application = QuerySet(Application).get(id=application_id) + application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first() + if application_access_token is None: + raise AppUnauthorizedFailed(500, "非法用户") return ApplicationSerializer.Query.reset_application( - ApplicationSerializer.ApplicationModel(application).data) + {**ApplicationSerializer.ApplicationModel(application).data, + 'show_source': application_access_token.show_source}) def edit(self, instance: Dict, with_valid=True): if with_valid: diff --git a/apps/application/serializers/chat_serializers.py b/apps/application/serializers/chat_serializers.py index dc69c1475..6c4a68707 100644 --- a/apps/application/serializers/chat_serializers.py +++ b/apps/application/serializers/chat_serializers.py @@ -23,6 +23,7 @@ from django.http import HttpResponse from rest_framework import serializers from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord +from application.models.api_key_model import ApplicationAccessToken from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \ ModelSettingSerializer from application.serializers.chat_message_serializers import ChatInfo @@ -277,17 +278,27 @@ class ChatRecordSerializerModel(serializers.ModelSerializer): class ChatRecordSerializer(serializers.Serializer): class Operate(serializers.Serializer): chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id")) - + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id")) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + if application_access_token is None: + raise AppApiException(500, '不存在的应用认证信息') + if not application_access_token.show_source: + raise AppApiException(500, '未开启显示知识来源') + def get_chat_record(self): chat_record_id = self.data.get('chat_record_id') chat_id = self.data.get('chat_id') chat_info: ChatInfo = chat_cache.get(chat_id) - chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if - chat_record.id == uuid.UUID(chat_record_id)] - if chat_record_list is not None and len(chat_record_list): - return chat_record_list[-1] + if chat_info is not None: + chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if + chat_record.id == uuid.UUID(chat_record_id)] + if chat_record_list is not None and len(chat_record_list): + return chat_record_list[-1] return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first() def one(self, with_valid=True): diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index e92d67801..42f2c6a9b 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -132,6 +132,8 @@ class ApplicationApi(ApiMixin): 'white_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表", description="白名单列表"), + 'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源", + description="是否显示知识来源"), } ) diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py index 7e968038f..ff8eabef4 100644 --- a/apps/application/views/chat_views.py +++ b/apps/application/views/chat_views.py @@ -172,7 +172,8 @@ class ChatView(APIView): tags=["应用/对话日志"] ) @has_permissions( - ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, dynamic_tag=keywords.get('application_id'))]) )