From f9a76d79484062741199bf910e243615f3b6fa7b Mon Sep 17 00:00:00 2001
From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com>
Date: Mon, 9 Sep 2024 14:47:25 +0800
Subject: [PATCH] feat: support the OpenAI-compatible API #353 (#1128)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../chat_pipeline/pipeline_manage.py          | 16 +++-
 .../step/chat_step/impl/base_chat_step.py     | 25 +++---
 apps/application/flow/workflow_manage.py      | 52 ++++++++---
 .../serializers/chat_message_serializers.py   | 88 ++++++++++++++++---
 apps/application/swagger_api/chat_api.py      | 25 ++++++
 apps/application/urls.py                      |  8 +-
 apps/application/views/chat_views.py          | 20 ++++-
 apps/common/auth/authenticate.py              | 20 +++++
 apps/common/exception/app_exception.py        |  8 ++
 apps/common/handle/base_to_response.py        | 27 ++++++
 .../impl/response/openai_to_response.py       | 42 +++++++++
 .../impl/response/system_to_response.py       | 26 ++++++
 .../local_model_provider/model/reranker.py    |  4 +
 .../impl/openai_model_provider/model/llm.py   |  4 +-
 14 files changed, 317 insertions(+), 48 deletions(-)
 create mode 100644 apps/common/handle/base_to_response.py
 create mode 100644 apps/common/handle/impl/response/openai_to_response.py
 create mode 100644 apps/common/handle/impl/response/system_to_response.py

diff --git a/apps/application/chat_pipeline/pipeline_manage.py b/apps/application/chat_pipeline/pipeline_manage.py
index 37d7736b5..7c4acb3a3 100644
--- a/apps/application/chat_pipeline/pipeline_manage.py
+++ b/apps/application/chat_pipeline/pipeline_manage.py
@@ -11,14 +11,18 @@
 from functools import reduce
 from typing import List, Type, Dict
 
 from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
 
 
 class PipelineManage:
-    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
+    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
+                 base_to_response: BaseToResponse = SystemToResponse()):
         # 步骤执行器
         self.step_list = [step() for step in step_list]
         # 上下文
         self.context = {'message_tokens': 0, 'answer_tokens': 0}
+        self.base_to_response = base_to_response
 
     def run(self, context: Dict = None):
         self.context['start_time'] = time.time()
@@ -33,13 +37,21 @@ class PipelineManage:
                        filter(lambda r: r is not None,
                               [row.get_details(self) for row in self.step_list])], {})
 
+    def get_base_to_response(self):
+        return self.base_to_response
+
     class builder:
         def __init__(self):
             self.step_list: List[Type[IBaseChatPipelineStep]] = []
+            self.base_to_response = SystemToResponse()
 
         def append_step(self, step: Type[IBaseChatPipelineStep]):
             self.step_list.append(step)
             return self
 
+        def add_base_to_response(self, base_to_response: BaseToResponse):
+            self.base_to_response = base_to_response
+            return self
+
         def build(self):
-            return PipelineManage(step_list=self.step_list)
+            return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index a4bced004..4cad1796e 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -6,7 +6,6 @@
     @date:2024/1/9 18:25
     @desc: 对话step Base实现
 """
-import json
 import logging
 import time
 import traceback
@@ -19,13 +18,13 @@
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import BaseMessage
 from langchain.schema.messages import HumanMessage, AIMessage
 from langchain_core.messages import AIMessageChunk
+from rest_framework import status
 
 from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
 from application.chat_pipeline.pipeline_manage import PipelineManage
 from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
 from application.models.api_key_model import ApplicationPublicAccessClient
 from common.constants.authentication_type import AuthenticationType
-from common.response import result
 from setting.models_provider.tools import get_model_instance_by_model_user_id
@@ -66,9 +65,9 @@ def event_content(response,
     try:
         for chunk in response:
             all_text += chunk.content
-            yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                         'content': chunk.content, 'is_end': False}) + "\n\n"
-
+            yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), chunk.content,
+                                                                         False,
+                                                                         0, 0)
             # 获取token
             if is_ai_chat:
                 try:
@@ -83,8 +82,8 @@ def event_content(response,
         write_context(step, manage, request_token, response_token, all_text)
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                       all_text, manage, step, padding_problem_text, client_id)
-        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                     'content': '', 'is_end': True}) + "\n\n"
+        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), '', True,
+                                                                     request_token, response_token)
         add_access_num(client_id, client_type)
     except Exception as e:
         logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
@@ -93,8 +92,7 @@ def event_content(response,
         post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                       all_text, manage, step, padding_problem_text, client_id)
         add_access_num(client_id, client_type)
-        yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                     'content': all_text, 'is_end': True}) + "\n\n"
+        yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), all_text, True, 0, 0)
 
 
 class BaseChatStep(IChatStep):
@@ -234,13 +232,14 @@ class BaseChatStep(IChatStep):
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                           chat_result.content, manage, self, padding_problem_text, client_id)
             add_access_num(client_id, client_type)
-            return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                   'content': chat_result.content, 'is_end': True})
+            return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
+                                                                   chat_result.content, True,
+                                                                   request_token, response_token)
         except Exception as e:
             all_text = '异常' + str(e)
             write_context(self, manage, 0, 0, all_text)
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                           all_text, manage, self, padding_problem_text, client_id)
             add_access_num(client_id, client_type)
-            return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
-                                   'content': all_text, 'is_end': True})
+            return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
+                                                                   0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py
index 65097d788..e6a1b4117 100644
--- a/apps/application/flow/workflow_manage.py
+++ b/apps/application/flow/workflow_manage.py
@@ -11,14 +11,16 @@
 from functools import reduce
 from typing import List, Dict
 
 from django.db.models import QuerySet
-from langchain_core.messages import AIMessage
 from langchain_core.prompts import PromptTemplate
+from rest_framework import status
 from rest_framework.exceptions import ErrorDetail, ValidationError
 
 from application.flow import tools
 from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
 from application.flow.step_node import get_node
 from common.exception.app_exception import AppApiException
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
 from function_lib.models.function import FunctionLib
 from setting.models import Model
 from setting.models_provider import get_model_credential
@@ -163,7 +165,8 @@ class Flow:
 
 
 class WorkflowManage:
-    def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
+    def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
+                 base_to_response: BaseToResponse = SystemToResponse()):
         self.params = params
         self.flow = flow
         self.context = {}
@@ -172,6 +175,7 @@ class WorkflowManage:
         self.current_node = None
         self.current_result = None
         self.answer = ""
+        self.base_to_response = base_to_response
 
     def run(self):
         if self.params.get('stream'):
@@ -189,13 +193,19 @@
             if result is not None:
                 list(result)
             if not self.has_next_node(self.current_result):
-                return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
-                                                AIMessage(self.answer), self,
-                                                self.work_flow_post_handler)
+                details = self.get_runtime_details()
+                message_tokens = sum([row.get('message_tokens') for row in details.values() if
+                                      'message_tokens' in row and row.get('message_tokens') is not None])
+                answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+                                     'answer_tokens' in row and row.get('answer_tokens') is not None])
+                return self.base_to_response.to_block_response(self.params['chat_id'],
+                                                               self.params['chat_record_id'], self.answer, True,
+                                                               message_tokens, answer_tokens)
         except Exception as e:
-            return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
-                                     AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
-                                     self.work_flow_post_handler)
+            self.current_node.get_write_error_context(e)
+            return self.base_to_response.to_block_response(self.params['chat_id'], self.params['chat_record_id'],
+                                                           str(e), True,
+                                                           0, 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
 
     def run_stream(self):
         return tools.to_stream_response_simple(self.stream_event())
@@ -208,16 +218,29 @@
             self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params)
             self.current_result = self.current_node.run()
             result = self.current_result.write_context(self.current_node, self)
+            has_next_node = self.has_next_node(self.current_result)
             if result is not None:
                 if self.is_result():
                     for r in result:
-                        yield self.get_chunk_content(r)
-                    yield self.get_chunk_content('\n')
-                    self.answer += '\n'
+                        yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                             self.params['chat_record_id'],
+                                                                             r, False, 0, 0)
+                    if has_next_node:
+                        yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                             self.params['chat_record_id'],
+                                                                             '\n',
+                                                                             False, 0, 0)
+                        self.answer += '\n'
                 else:
                     list(result)
-            if not self.has_next_node(self.current_result):
-                yield self.get_chunk_content('', True)
+            if not has_next_node:
+                details = self.get_runtime_details()
+                message_tokens = sum([row.get('message_tokens') for row in details.values() if
+                                      'message_tokens' in row and row.get('message_tokens') is not None])
+                answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
+                                     'answer_tokens' in row and row.get('answer_tokens') is not None])
+                yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
+                                                                     self.params['chat_record_id'],
+                                                                     '', True, message_tokens, answer_tokens)
                 break
         self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.answer,
                                             self)
@@ -228,7 +251,8 @@
             self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], self.answer,
                                                 self)
-            yield self.get_chunk_content(str(e), True)
+            yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'],
+                                                                 str(e), True, 0, 0)
 
     def is_result(self):
         """
diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py
index 4a2e7e47c..a994481ab 100644
--- a/apps/application/serializers/chat_message_serializers.py
+++ b/apps/application/serializers/chat_message_serializers.py
@@ -7,7 +7,7 @@
     @desc:
 """
 import uuid
-from typing import List
+from typing import List, Dict
 from uuid import UUID
 
 from django.core.cache import caches
@@ -27,7 +27,10 @@
 from application.models import ChatRecord, Chat, Application, ApplicationDatasetMapping, \
     WorkFlowVersion
 from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
 from common.constants.authentication_type import AuthenticationType
-from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
+from common.exception.app_exception import AppChatNumOutOfBoundsFailed, ChatException
+from common.handle.base_to_response import BaseToResponse
+from common.handle.impl.response.openai_to_response import OpenaiToResponse
+from common.handle.impl.response.system_to_response import SystemToResponse
 from common.util.field_message import ErrMessage
 from common.util.split_model import flat_map
 from dataset.models import Paragraph, Document
@@ -145,6 +148,58 @@ def get_post_handler(chat_info: ChatInfo):
     return PostHandler()
 
 
+class OpenAIMessage(serializers.Serializer):
+    content = serializers.CharField(required=True, error_messages=ErrMessage.char('内容'))
+    role = serializers.CharField(required=True, error_messages=ErrMessage.char('角色'))
+
+
+class OpenAIInstanceSerializer(serializers.Serializer):
+    messages = serializers.ListField(child=OpenAIMessage())
+    chat_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char("对话id"))
+    re_chat = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("重新生成"))
+    stream = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("流式输出"))
+
+
+class OpenAIChatSerializer(serializers.Serializer):
+    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
+    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
+    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
+
+    @staticmethod
+    def get_message(instance):
+        return instance.get('messages')[-1].get('content')
+
+    @staticmethod
+    def generate_chat(chat_id, application_id, message, client_id):
+        if chat_id is None:
+            chat_id = str(uuid.uuid1())
+        chat = QuerySet(Chat).filter(id=chat_id).first()
+        if chat is None:
+            Chat(id=chat_id, application_id=application_id, abstract=message, client_id=client_id).save()
+        return chat_id
+
+    def chat(self, instance: Dict, with_valid=True):
+        if with_valid:
+            self.is_valid(raise_exception=True)
+            OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
+        chat_id = instance.get('chat_id')
+        message = self.get_message(instance)
+        re_chat = instance.get('re_chat', False)
+        stream = instance.get('stream', False)
+        application_id = self.data.get('application_id')
+        client_id = self.data.get('client_id')
+        client_type = self.data.get('client_type')
+        chat_id = self.generate_chat(chat_id, application_id, message, client_id)
+        return ChatMessageSerializer(
+            data={'chat_id': chat_id, 'message': message,
+                  're_chat': re_chat,
+                  'stream': stream,
+                  'application_id': application_id,
+                  'client_id': client_id,
+                  'client_type': client_type}).chat(
+            base_to_response=OpenaiToResponse())
+
+
 class ChatMessageSerializer(serializers.Serializer):
     chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
     message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
@@ -157,6 +212,10 @@ class ChatMessageSerializer(serializers.Serializer):
     def is_valid_application_workflow(self, *, raise_exception=False):
         self.is_valid_intraday_access_num()
 
+    def is_valid_chat_id(self, chat_info: ChatInfo):
+        if self.data.get('application_id') != str(chat_info.application.id):
+            raise ChatException(500, "会话不存在")
+
     def is_valid_intraday_access_num(self):
         if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
             access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
@@ -181,12 +240,12 @@ class ChatMessageSerializer(serializers.Serializer):
         if model is None:
             return chat_info
         if model.status == Status.ERROR:
-            raise AppApiException(500, "当前模型不可用")
+            raise ChatException(500, "当前模型不可用")
         if model.status == Status.DOWNLOAD:
-            raise AppApiException(500, "模型正在下载中,请稍后再发起对话")
+            raise ChatException(500, "模型正在下载中,请稍后再发起对话")
         return chat_info
 
-    def chat_simple(self, chat_info: ChatInfo):
+    def chat_simple(self, chat_info: ChatInfo, base_to_response):
         message = self.data.get('message')
         re_chat = self.data.get('re_chat')
         stream = self.data.get('stream')
@@ -200,6 +259,7 @@ class ChatMessageSerializer(serializers.Serializer):
         pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
                             .append_step(BaseGenerateHumanMessageStep)
                             .append_step(BaseChatStep)
+                            .add_base_to_response(base_to_response)
                             .build())
         exclude_paragraph_id_list = []
         # 相同问题是否需要排除已经查询到的段落
@@ -217,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
         pipeline_message.run(params)
         return pipeline_message.context['chat_result']
 
-    def chat_work_flow(self, chat_info: ChatInfo):
+    def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
         message = self.data.get('message')
         re_chat = self.data.get('re_chat')
         stream = self.data.get('stream')
@@ -229,19 +289,21 @@
                                           'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
                                           'stream': stream,
                                           're_chat': re_chat,
-                                          'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
+                                          'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
+                                         base_to_response)
         r = work_flow_manage.run()
         return r
 
-    def chat(self):
+    def chat(self, base_to_response: BaseToResponse = SystemToResponse()):
         super().is_valid(raise_exception=True)
         chat_info = self.get_chat_info()
+        self.is_valid_chat_id(chat_info)
         if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
             self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
-            return self.chat_simple(chat_info)
+            return self.chat_simple(chat_info, base_to_response)
         else:
             self.is_valid_application_workflow(raise_exception=True)
-            return self.chat_work_flow(chat_info)
+            return self.chat_work_flow(chat_info, base_to_response)
 
     def get_chat_info(self):
         self.is_valid(raise_exception=True)
@@ -256,10 +318,10 @@
     def re_open_chat(self, chat_id: str):
         chat = QuerySet(Chat).filter(id=chat_id).first()
         if chat is None:
-            raise AppApiException(500, "会话不存在")
+            raise ChatException(500, "会话不存在")
         application = QuerySet(Application).filter(id=chat.application_id).first()
         if application is None:
-            raise AppApiException(500, "应用不存在")
+            raise ChatException(500, "应用不存在")
         if application.type == ApplicationTypeChoices.SIMPLE:
             return self.re_open_chat_simple(chat_id, application)
         else:
@@ -289,7 +351,7 @@ class ChatMessageSerializer(serializers.Serializer):
         work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
             '-create_time')[0:1].first()
         if work_flow_version is None:
-            raise AppApiException(500, "应用未发布,请发布后再使用")
+            raise ChatException(500, "应用未发布,请发布后再使用")
         chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
         chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])
diff --git a/apps/application/swagger_api/chat_api.py b/apps/application/swagger_api/chat_api.py
index 2ff8f8ac5..cc2a50097 100644
--- a/apps/application/swagger_api/chat_api.py
+++ b/apps/application/swagger_api/chat_api.py
@@ -23,6 +23,31 @@ class ChatClientHistoryApi(ApiMixin):
         ]
 
 
+class OpenAIChatApi(ApiMixin):
+    @staticmethod
+    def get_request_body_api():
+        return openapi.Schema(type=openapi.TYPE_OBJECT,
+                              required=['messages'],
+                              properties={
+                                  'messages': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题", description="问题",
+                                                             items=openapi.Schema(type=openapi.TYPE_OBJECT,
+                                                                                  required=['role', 'content'],
+                                                                                  properties={
+                                                                                      'content': openapi.Schema(
+                                                                                          type=openapi.TYPE_STRING,
+                                                                                          title="问题内容", default=''),
+                                                                                      'role': openapi.Schema(
+                                                                                          type=openapi.TYPE_STRING,
+                                                                                          title='角色', default="user")
+                                                                                  }
+                                                                                  )),
+                                  'chat_id': openapi.Schema(type=openapi.TYPE_STRING, title="对话id"),
+                                  're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
+                                  'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="流式输出", default=True)
+
+                              })
+
+
 class ChatApi(ApiMixin):
     @staticmethod
     def get_request_body_api():
diff --git a/apps/application/urls.py b/apps/application/urls.py
index 5a41bc59c..2555b7dcd 100644
--- a/apps/application/urls.py
+++ b/apps/application/urls.py
@@ -42,6 +42,8 @@ urlpatterns = [
     path("application/<str:application_id>/chat/client/<str:client_id>/", views.ChatView.ClientChatHistoryPage.Operate.as_view()),
     path('application/<str:application_id>/chat/export', views.ChatView.Export.as_view(), name='export'),
+    path('application/<str:application_id>/chat/completions', views.Openai.as_view(),
+         name='application/chat_completions'),
     path('application/<str:application_id>/chat', views.ChatView.as_view(), name='chats'),
     path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
     path('application/<str:application_id>/chat/<str:chat_id>', views.ChatView.Operate.as_view()),
@@ -64,7 +66,9 @@
         'application/<str:application_id>/chat/<str:chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
         views.ChatView.ChatRecord.Improve.Operate.as_view(),
         name=''),
-    path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
-    path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),
+    path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(),
+         name='application/audio'),
+    path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(),
+         name='application/audio'),
 ]
diff --git a/apps/application/views/chat_views.py b/apps/application/views/chat_views.py
index 577c08838..288e8a1fc 100644
--- a/apps/application/views/chat_views.py
+++ b/apps/application/views/chat_views.py
@@ -6,16 +6,17 @@
     @date:2023/11/14 9:53
     @desc:
 """
+
 from drf_yasg.utils import swagger_auto_schema
 from rest_framework.decorators import action
 from rest_framework.request import Request
 from rest_framework.views import APIView
 
-from application.serializers.chat_message_serializers import ChatMessageSerializer
+from application.serializers.chat_message_serializers import ChatMessageSerializer, OpenAIChatSerializer
 from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
 from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi, \
-    ChatClientHistoryApi
-from common.auth import TokenAuth, has_permissions
+    ChatClientHistoryApi, OpenAIChatApi
+from common.auth import TokenAuth, has_permissions, OpenAIKeyAuth
 from common.constants.authentication_type import AuthenticationType
 from common.constants.permission_constants import Permission, Group, Operate, \
     RoleConstants, ViewPermission, CompareConstants
@@ -23,6 +24,19 @@
 from common.response import result
 from common.util.common import query_params_to_single_dict
 
 
+class Openai(APIView):
+    authentication_classes = [OpenAIKeyAuth]
+
+    @action(methods=['POST'], detail=False)
+    @swagger_auto_schema(operation_summary="openai接口对话",
+                         operation_id="openai接口对话",
+                         request_body=OpenAIChatApi.get_request_body_api(),
+                         tags=["openai对话"])
+    def post(self, request: Request, application_id: str):
+        return OpenAIChatSerializer(data={'application_id': application_id, 'client_id': request.auth.client_id,
+                                          'client_type': request.auth.client_type}).chat(request.data)
+
+
 class ChatView(APIView):
     authentication_classes = [TokenAuth]
diff --git a/apps/common/auth/authenticate.py b/apps/common/auth/authenticate.py
index 424f85965..a27cdb144 100644
--- a/apps/common/auth/authenticate.py
+++ b/apps/common/auth/authenticate.py
@@ -52,6 +52,26 @@ class TokenDetails:
         return self.token_details
 
 
+class OpenAIKeyAuth(TokenAuthentication):
+    def authenticate(self, request):
+        auth = request.META.get('HTTP_AUTHORIZATION')
+        # 未认证
+        if auth is None:
+            raise AppAuthenticationFailed(1003, '未登录,请先登录')
+        auth = auth.replace('Bearer ', '')
+        try:
+            token_details = TokenDetails(auth)
+            for handle in handles:
+                if handle.support(request, auth, token_details.get_token_details):
+                    return handle.handle(request, auth, token_details.get_token_details)
+            raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
+        except Exception as e:
+            traceback.format_exc()
+            if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed):
+                raise e
+            raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
+
+
 class TokenAuth(TokenAuthentication):
     # 重新 authenticate 方法,自定义认证规则
     def authenticate(self, request):
diff --git a/apps/common/exception/app_exception.py b/apps/common/exception/app_exception.py
index 3646efb0c..b8f5602e7 100644
--- a/apps/common/exception/app_exception.py
+++ b/apps/common/exception/app_exception.py
@@ -73,3 +73,11 @@ class AppChatNumOutOfBoundsFailed(AppApiException):
     def __init__(self, code, message):
         self.code = code
         self.message = message
+
+
+class ChatException(AppApiException):
+    status_code = 500
+
+    def __init__(self, code, message):
+        self.code = code
+        self.message = message
diff --git a/apps/common/handle/base_to_response.py b/apps/common/handle/base_to_response.py
new file mode 100644
index 000000000..05af57cb9
--- /dev/null
+++ b/apps/common/handle/base_to_response.py
@@ -0,0 +1,27 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: base_to_response.py
+    @date:2024/9/6 16:04
+    @desc:
+"""
+from abc import ABC, abstractmethod
+
+from rest_framework import status
+
+
+class BaseToResponse(ABC):
+
+    @abstractmethod
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+                          _status=status.HTTP_200_OK):
+        pass
+
+    @abstractmethod
+    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+        pass
+
+    @staticmethod
+    def format_stream_chunk(response_str):
+        return 'data: ' + response_str + '\n\n'
diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py
new file mode 100644
index 000000000..791224aff
--- /dev/null
+++ b/apps/common/handle/impl/response/openai_to_response.py
@@ -0,0 +1,42 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: openai_to_response.py
+    @date:2024/9/6 16:08
+    @desc:
+"""
+import datetime
+
+from django.http import JsonResponse
+from openai.types import CompletionUsage
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage, ChatCompletion
+from openai.types.chat.chat_completion import Choice as BlockChoice
+from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
+from rest_framework import status
+
+from common.handle.base_to_response import BaseToResponse
+
+
+class OpenaiToResponse(BaseToResponse):
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+                          _status=status.HTTP_200_OK):
+        data = ChatCompletion(id=chat_record_id, choices=[
+            BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
+                        message=ChatCompletionMessage(role='assistant', content=content))],
+                              created=int(datetime.datetime.now().timestamp()), model='',
+                              object='chat.completion',
+                              usage=CompletionUsage(completion_tokens=completion_tokens,
+                                                    prompt_tokens=prompt_tokens,
+                                                    total_tokens=completion_tokens + prompt_tokens)
+                              ).dict()
+        return JsonResponse(data=data, status=_status)
+
+    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+        chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
+                                    created=int(datetime.datetime.now().timestamp()), choices=[
+                Choice(delta=ChoiceDelta(content=content, chat_id=chat_id),
+                       finish_reason='stop' if is_end else None,
+                       index=0)],
+                                    usage=CompletionUsage(completion_tokens=completion_tokens,
+                                                          prompt_tokens=prompt_tokens,
+                                                          total_tokens=completion_tokens + prompt_tokens)).json()
+        return super().format_stream_chunk(chunk)
diff --git a/apps/common/handle/impl/response/system_to_response.py b/apps/common/handle/impl/response/system_to_response.py
new file mode 100644
index 000000000..1ec980633
--- /dev/null
+++ b/apps/common/handle/impl/response/system_to_response.py
@@ -0,0 +1,26 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: system_to_response.py
+    @date:2024/9/6 18:03
+    @desc:
+"""
+import json
+
+from rest_framework import status
+
+from common.handle.base_to_response import BaseToResponse
+from common.response import result
+
+
+class SystemToResponse(BaseToResponse):
+    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
+                          _status=status.HTTP_200_OK):
+        return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                               'content': content, 'is_end': is_end}, response_status=_status, code=_status)
+
+    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
+        chunk = json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
+                            'content': content, 'is_end': is_end})
+        return super().format_stream_chunk(chunk)
diff --git a/apps/setting/models_provider/impl/local_model_provider/model/reranker.py b/apps/setting/models_provider/impl/local_model_provider/model/reranker.py
index a2f932329..4ada1eb35 100644
--- a/apps/setting/models_provider/impl/local_model_provider/model/reranker.py
+++ b/apps/setting/models_provider/impl/local_model_provider/model/reranker.py
@@ -50,6 +50,8 @@ class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
 
     def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
             Sequence[Document]:
+        if documents is None or len(documents) == 0:
+            return []
         bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
         res = requests.post(
             f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
@@ -85,6 +87,8 @@ class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
 
     def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
             Sequence[Document]:
+        if documents is None or len(documents) == 0:
+            return []
         with torch.no_grad():
             inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
                                     truncation=True, return_tensors='pt', max_length=512)
diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
index cdd187d9c..ff80b0e50 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py
@@ -6,9 +6,11 @@
     @date:2024/4/18 15:28
     @desc:
 """
-from typing import List, Dict
+from typing import List, Dict, Optional, Any
 
+from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.runnables import RunnableConfig
 from langchain_openai.chat_models import ChatOpenAI
 
 from common.config.tokenizer_manage_config import TokenizerManage
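
With this patch applied, an application can also be driven through an OpenAI-compatible
route: POST application/<application_id>/chat/completions, authenticated by OpenAIKeyAuth
and rendered by OpenaiToResponse (a ChatCompletion body for block calls, 'data: ...' SSE
chunks for streaming). The sketch below shows one way to exercise the endpoint with the
official openai Python client; the host, the /api URL prefix, and the API-key value are
deployment-specific assumptions, not something this patch defines.

    from openai import OpenAI

    # Assumed deployment details: point base_url and api_key at your MaxKB instance.
    client = OpenAI(base_url="http://localhost:8080/api/application/<application_id>",
                    api_key="<application-api-key>")

    # Block call: answered by OpenaiToResponse.to_block_response. Only the last
    # message's content is used as the question (OpenAIChatSerializer.get_message).
    completion = client.chat.completions.create(
        model="maxkb",  # the model name is not used; responses carry model=''
        messages=[{"role": "user", "content": "Hello"}])
    print(completion.choices[0].message.content)

    # Streaming call: each SSE chunk is built by to_stream_chunk_response, with
    # finish_reason='stop' on the final chunk.
    for chunk in client.chat.completions.create(
            model="maxkb",
            messages=[{"role": "user", "content": "Hello"}],
            stream=True):
        print(chunk.choices[0].delta.content or "", end="")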