feat: 支持openai接口 #353 (#1128)

This commit is contained in:
shaohuzhang1 2024-09-09 14:47:25 +08:00 committed by GitHub
parent f882249216
commit f9a76d7948
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 317 additions and 48 deletions

View File

@ -11,14 +11,18 @@ from functools import reduce
from typing import List, Type, Dict
from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
class PipelineManage:
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]]):
def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
base_to_response: BaseToResponse = SystemToResponse()):
# 步骤执行器
self.step_list = [step() for step in step_list]
# 上下文
self.context = {'message_tokens': 0, 'answer_tokens': 0}
self.base_to_response = base_to_response
def run(self, context: Dict = None):
self.context['start_time'] = time.time()
@ -33,13 +37,21 @@ class PipelineManage:
filter(lambda r: r is not None,
[row.get_details(self) for row in self.step_list])], {})
def get_base_to_response(self):
return self.base_to_response
class builder:
def __init__(self):
self.step_list: List[Type[IBaseChatPipelineStep]] = []
self.base_to_response = SystemToResponse()
def append_step(self, step: Type[IBaseChatPipelineStep]):
self.step_list.append(step)
return self
def add_base_to_response(self, base_to_response: BaseToResponse):
self.base_to_response = base_to_response
return self
def build(self):
return PipelineManage(step_list=self.step_list)
return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)

View File

@ -6,7 +6,6 @@
@date2024/1/9 18:25
@desc: 对话step Base实现
"""
import json
import logging
import time
import traceback
@ -19,13 +18,13 @@ from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.messages import HumanMessage, AIMessage
from langchain_core.messages import AIMessageChunk
from rest_framework import status
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.pipeline_manage import PipelineManage
from application.chat_pipeline.step.chat_step.i_chat_step import IChatStep, PostResponseHandler
from application.models.api_key_model import ApplicationPublicAccessClient
from common.constants.authentication_type import AuthenticationType
from common.response import result
from setting.models_provider.tools import get_model_instance_by_model_user_id
@ -66,9 +65,9 @@ def event_content(response,
try:
for chunk in response:
all_text += chunk.content
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chunk.content, 'is_end': False}) + "\n\n"
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), chunk.content,
False,
0, 0)
# 获取token
if is_ai_chat:
try:
@ -83,8 +82,8 @@ def event_content(response,
write_context(step, manage, request_token, response_token, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': '', 'is_end': True}) + "\n\n"
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), '', True,
request_token, response_token)
add_access_num(client_id, client_type)
except Exception as e:
logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
@ -93,8 +92,7 @@ def event_content(response,
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, step, padding_problem_text, client_id)
add_access_num(client_id, client_type)
yield 'data: ' + json.dumps({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': all_text, 'is_end': True}) + "\n\n"
yield manage.get_base_to_response().to_stream_chunk_response(chat_id, str(chat_record_id), all_text, True, 0, 0)
class BaseChatStep(IChatStep):
@ -234,13 +232,14 @@ class BaseChatStep(IChatStep):
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
chat_result.content, manage, self, padding_problem_text, client_id)
add_access_num(client_id, client_type)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': chat_result.content, 'is_end': True})
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
chat_result.content, True,
request_token, response_token)
except Exception as e:
all_text = '异常' + str(e)
write_context(self, manage, 0, 0, all_text)
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
all_text, manage, self, padding_problem_text, client_id)
add_access_num(client_id, client_type)
return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
'content': all_text, 'is_end': True})
return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@ -11,14 +11,16 @@ from functools import reduce
from typing import List, Dict
from django.db.models import QuerySet
from langchain_core.messages import AIMessage
from langchain_core.prompts import PromptTemplate
from rest_framework import status
from rest_framework.exceptions import ErrorDetail, ValidationError
from application.flow import tools
from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult
from application.flow.step_node import get_node
from common.exception.app_exception import AppApiException
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from function_lib.models.function import FunctionLib
from setting.models import Model
from setting.models_provider import get_model_credential
@ -163,7 +165,8 @@ class Flow:
class WorkflowManage:
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler):
def __init__(self, flow: Flow, params, work_flow_post_handler: WorkFlowPostHandler,
base_to_response: BaseToResponse = SystemToResponse()):
self.params = params
self.flow = flow
self.context = {}
@ -172,6 +175,7 @@ class WorkflowManage:
self.current_node = None
self.current_result = None
self.answer = ""
self.base_to_response = base_to_response
def run(self):
if self.params.get('stream'):
@ -189,13 +193,19 @@ class WorkflowManage:
if result is not None:
list(result)
if not self.has_next_node(self.current_result):
return tools.to_response_simple(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(self.answer), self,
self.work_flow_post_handler)
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
return self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], self.answer, True
, message_tokens, answer_tokens)
except Exception as e:
return tools.to_response(self.params['chat_id'], self.params['chat_record_id'],
AIMessage(str(e)), self, self.current_node.get_write_error_context(e),
self.work_flow_post_handler)
self.current_node.get_write_error_context(e)
return self.base_to_response.to_block_response(self.params['chat_id'], self.params['chat_record_id'],
str(e), True,
0, 0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
def run_stream(self):
return tools.to_stream_response_simple(self.stream_event())
@ -208,16 +218,29 @@ class WorkflowManage:
self.current_node.valid_args(self.current_node.node_params, self.current_node.workflow_params)
self.current_result = self.current_node.run()
result = self.current_result.write_context(self.current_node, self)
has_next_node = self.has_next_node(self.current_result)
if result is not None:
if self.is_result():
for r in result:
yield self.get_chunk_content(r)
yield self.get_chunk_content('\n')
self.answer += '\n'
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
r, False, 0, 0)
if has_next_node:
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
'\n', False, 0, 0)
self.answer += '\n'
else:
list(result)
if not self.has_next_node(self.current_result):
yield self.get_chunk_content('', True)
if not has_next_node:
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'],
self.params['chat_record_id'],
'', True, message_tokens, answer_tokens)
break
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer,
@ -228,7 +251,8 @@ class WorkflowManage:
self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'],
self.answer,
self)
yield self.get_chunk_content(str(e), True)
yield self.base_to_response.to_stream_chunk_response(self.params['chat_id'], self.params['chat_record_id'],
str(e), True, 0, 0)
def is_result(self):
"""

View File

@ -7,7 +7,7 @@
@desc:
"""
import uuid
from typing import List
from typing import List, Dict
from uuid import UUID
from django.core.cache import caches
@ -27,7 +27,10 @@ from application.models import ChatRecord, Chat, Application, ApplicationDataset
WorkFlowVersion
from application.models.api_key_model import ApplicationPublicAccessClient, ApplicationAccessToken
from common.constants.authentication_type import AuthenticationType
from common.exception.app_exception import AppApiException, AppChatNumOutOfBoundsFailed
from common.exception.app_exception import AppChatNumOutOfBoundsFailed, ChatException
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.openai_to_response import OpenaiToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.util.field_message import ErrMessage
from common.util.split_model import flat_map
from dataset.models import Paragraph, Document
@ -145,6 +148,58 @@ def get_post_handler(chat_info: ChatInfo):
return PostHandler()
class OpenAIMessage(serializers.Serializer):
    """One OpenAI-style chat message: an object with 'role' and 'content' fields."""
    # Message text (required).
    content = serializers.CharField(required=True, error_messages=ErrMessage.char('内容'))
    # Message author role, e.g. "user" / "assistant" (required).
    role = serializers.CharField(required=True, error_messages=ErrMessage.char('角色'))
class OpenAIInstanceSerializer(serializers.Serializer):
    """Validates the body of an OpenAI-compatible /chat/completions request."""
    # Conversation history; only the last element's content is used downstream.
    messages = serializers.ListField(child=OpenAIMessage())
    # Existing session id; a new one is generated when omitted.
    # NOTE(review): error_messages uses ErrMessage.char while the field is a UUIDField —
    # probably should be ErrMessage.uuid; confirm against project convention.
    chat_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char("对话id"))
    # Whether to regenerate the answer for a repeated question.
    re_chat = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("重新生成"))
    # Whether to stream the answer as SSE chunks.
    stream = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("流式输出"))
class OpenAIChatSerializer(serializers.Serializer):
    """Adapts an OpenAI-style chat request onto ChatMessageSerializer.

    Validates the caller identity fields, ensures a Chat session exists,
    then delegates to ChatMessageSerializer.chat() with the OpenAI
    response formatter so the reply matches the OpenAI wire format.
    """
    # Target application (required).
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    # Authenticated client identity, taken from request.auth by the view.
    client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
    client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))

    @staticmethod
    def get_message(instance):
        # Only the newest message in the conversation is used as the question.
        # NOTE(review): raises IndexError when 'messages' is an empty list —
        # ListField does not forbid empty input here; confirm upstream validation.
        return instance.get('messages')[-1].get('content')

    @staticmethod
    def generate_chat(chat_id, application_id, message, client_id):
        """Return a usable chat id, creating the Chat row on first use.

        When chat_id is None a new uuid1 id is generated; when no Chat row
        exists for the id, one is persisted with the question as its abstract.
        """
        if chat_id is None:
            chat_id = str(uuid.uuid1())
        chat = QuerySet(Chat).filter(id=chat_id).first()
        if chat is None:
            Chat(id=chat_id, application_id=application_id, abstract=message, client_id=client_id).save()
        return chat_id

    def chat(self, instance: Dict, with_valid=True):
        """Run one chat turn for an OpenAI-format request body.

        :param instance: parsed request body (messages / chat_id / re_chat / stream)
        :param with_valid: validate both this serializer and the body first
        :return: whatever ChatMessageSerializer.chat() returns, formatted by OpenaiToResponse
        """
        if with_valid:
            self.is_valid(raise_exception=True)
            OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
        chat_id = instance.get('chat_id')
        message = self.get_message(instance)
        re_chat = instance.get('re_chat', False)
        stream = instance.get('stream', False)
        application_id = self.data.get('application_id')
        client_id = self.data.get('client_id')
        client_type = self.data.get('client_type')
        chat_id = self.generate_chat(chat_id, application_id, message, client_id)
        # Delegate to the regular chat pipeline, but emit OpenAI-shaped responses.
        return ChatMessageSerializer(
            data={'chat_id': chat_id, 'message': message,
                  're_chat': re_chat,
                  'stream': stream,
                  'application_id': application_id,
                  'client_id': client_id,
                  'client_type': client_type}).chat(
            base_to_response=OpenaiToResponse())
class ChatMessageSerializer(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("对话id"))
message = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
@ -157,6 +212,10 @@ class ChatMessageSerializer(serializers.Serializer):
def is_valid_application_workflow(self, *, raise_exception=False):
self.is_valid_intraday_access_num()
def is_valid_chat_id(self, chat_info: ChatInfo):
if self.data.get('application_id') != str(chat_info.application.id):
raise ChatException(500, "会话不存在")
def is_valid_intraday_access_num(self):
if self.data.get('client_type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.data.get('client_id')).first()
@ -181,12 +240,12 @@ class ChatMessageSerializer(serializers.Serializer):
if model is None:
return chat_info
if model.status == Status.ERROR:
raise AppApiException(500, "当前模型不可用")
raise ChatException(500, "当前模型不可用")
if model.status == Status.DOWNLOAD:
raise AppApiException(500, "模型正在下载中,请稍后再发起对话")
raise ChatException(500, "模型正在下载中,请稍后再发起对话")
return chat_info
def chat_simple(self, chat_info: ChatInfo):
def chat_simple(self, chat_info: ChatInfo, base_to_response):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
@ -200,6 +259,7 @@ class ChatMessageSerializer(serializers.Serializer):
pipeline_message = (pipeline_manage_builder.append_step(BaseSearchDatasetStep)
.append_step(BaseGenerateHumanMessageStep)
.append_step(BaseChatStep)
.add_base_to_response(base_to_response)
.build())
exclude_paragraph_id_list = []
# 相同问题是否需要排除已经查询到的段落
@ -217,7 +277,7 @@ class ChatMessageSerializer(serializers.Serializer):
pipeline_message.run(params)
return pipeline_message.context['chat_result']
def chat_work_flow(self, chat_info: ChatInfo):
def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
message = self.data.get('message')
re_chat = self.data.get('re_chat')
stream = self.data.get('stream')
@ -229,19 +289,21 @@ class ChatMessageSerializer(serializers.Serializer):
'chat_id': chat_info.chat_id, 'chat_record_id': str(uuid.uuid1()),
'stream': stream,
're_chat': re_chat,
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type))
'user_id': user_id}, WorkFlowPostHandler(chat_info, client_id, client_type),
base_to_response)
r = work_flow_manage.run()
return r
def chat(self):
def chat(self, base_to_response: BaseToResponse = SystemToResponse()):
super().is_valid(raise_exception=True)
chat_info = self.get_chat_info()
self.is_valid_chat_id(chat_info)
if chat_info.application.type == ApplicationTypeChoices.SIMPLE:
self.is_valid_application_simple(raise_exception=True, chat_info=chat_info),
return self.chat_simple(chat_info)
return self.chat_simple(chat_info, base_to_response)
else:
self.is_valid_application_workflow(raise_exception=True)
return self.chat_work_flow(chat_info)
return self.chat_work_flow(chat_info, base_to_response)
def get_chat_info(self):
self.is_valid(raise_exception=True)
@ -256,10 +318,10 @@ class ChatMessageSerializer(serializers.Serializer):
def re_open_chat(self, chat_id: str):
chat = QuerySet(Chat).filter(id=chat_id).first()
if chat is None:
raise AppApiException(500, "会话不存在")
raise ChatException(500, "会话不存在")
application = QuerySet(Application).filter(id=chat.application_id).first()
if application is None:
raise AppApiException(500, "应用不存在")
raise ChatException(500, "应用不存在")
if application.type == ApplicationTypeChoices.SIMPLE:
return self.re_open_chat_simple(chat_id, application)
else:
@ -289,7 +351,7 @@ class ChatMessageSerializer(serializers.Serializer):
work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
'-create_time')[0:1].first()
if work_flow_version is None:
raise AppApiException(500, "应用未发布,请发布后再使用")
raise ChatException(500, "应用未发布,请发布后再使用")
chat_info = ChatInfo(chat_id, [], [], application, work_flow_version)
chat_record_list = list(QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5])

View File

@ -23,6 +23,31 @@ class ChatClientHistoryApi(ApiMixin):
]
class OpenAIChatApi(ApiMixin):
    """Swagger request-body schema for the OpenAI-compatible chat endpoint."""

    @staticmethod
    def get_request_body_api():
        """Return the openapi.Schema describing the OpenAI-style request body.

        Bug fix: required listed 'message', which is not a property of this
        schema — the actual property (and the field the serializer requires)
        is 'messages'.
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['messages'],
            properties={
                'messages': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题", description="问题",
                                           items=openapi.Schema(
                                               type=openapi.TYPE_OBJECT,
                                               required=['role', 'content'],
                                               properties={
                                                   'content': openapi.Schema(
                                                       type=openapi.TYPE_STRING,
                                                       title="问题内容", default=''),
                                                   'role': openapi.Schema(
                                                       type=openapi.TYPE_STRING,
                                                       title='角色', default="user")
                                               }
                                           )),
                'chat_id': openapi.Schema(type=openapi.TYPE_STRING, title="对话id"),
                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
                # NOTE(review): documented default True, but the serializer defaults
                # stream to False when omitted — confirm which is intended.
                'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="流式输出", default=True)
            })
class ChatApi(ApiMixin):
@staticmethod
def get_request_body_api():

View File

@ -42,6 +42,8 @@ urlpatterns = [
path("application/<str:application_id>/chat/client/<chat_id>",
views.ChatView.ClientChatHistoryPage.Operate.as_view()),
path('application/<str:application_id>/chat/export', views.ChatView.Export.as_view(), name='export'),
path('application/<str:application_id>/chat/completions', views.Openai.as_view(),
name='application/chat_completions'),
path('application/<str:application_id>/chat', views.ChatView.as_view(), name='chats'),
path('application/<str:application_id>/chat/<int:current_page>/<int:page_size>', views.ChatView.Page.as_view()),
path('application/<str:application_id>/chat/<chat_id>', views.ChatView.Operate.as_view()),
@ -64,7 +66,9 @@ urlpatterns = [
'application/<str:application_id>/chat/<chat_id>/chat_record/<str:chat_record_id>/dataset/<str:dataset_id>/document_id/<str:document_id>/improve/<str:paragraph_id>',
views.ChatView.ChatRecord.Improve.Operate.as_view(),
name=''),
path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(), name='application/audio'),
path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(), name='application/audio'),
path('application/<str:application_id>/speech_to_text', views.Application.SpeechToText.as_view(),
name='application/audio'),
path('application/<str:application_id>/text_to_speech', views.Application.TextToSpeech.as_view(),
name='application/audio'),
]

View File

@ -6,16 +6,17 @@
@date2023/11/14 9:53
@desc:
"""
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.request import Request
from rest_framework.views import APIView
from application.serializers.chat_message_serializers import ChatMessageSerializer
from application.serializers.chat_message_serializers import ChatMessageSerializer, OpenAIChatSerializer
from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer
from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi, \
ChatClientHistoryApi
from common.auth import TokenAuth, has_permissions
ChatClientHistoryApi, OpenAIChatApi
from common.auth import TokenAuth, has_permissions, OpenAIKeyAuth
from common.constants.authentication_type import AuthenticationType
from common.constants.permission_constants import Permission, Group, Operate, \
RoleConstants, ViewPermission, CompareConstants
@ -23,6 +24,19 @@ from common.response import result
from common.util.common import query_params_to_single_dict
class Openai(APIView):
    """OpenAI-compatible chat endpoint (POST .../chat/completions)."""
    # API-key based auth; puts client_id/client_type on request.auth.
    authentication_classes = [OpenAIKeyAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="openai接口对话",
                         operation_id="openai接口对话",
                         request_body=OpenAIChatApi.get_request_body_api(),
                         tags=["openai对话"])
    def post(self, request: Request, application_id: str):
        # Identity comes from the authenticated key, the body from the client;
        # the serializer returns an OpenAI-formatted (block or streaming) response.
        return OpenAIChatSerializer(data={'application_id': application_id, 'client_id': request.auth.client_id,
                                          'client_type': request.auth.client_type}).chat(request.data)
class ChatView(APIView):
authentication_classes = [TokenAuth]

View File

@ -52,6 +52,26 @@ class TokenDetails:
return self.token_details
class OpenAIKeyAuth(TokenAuthentication):
    """Authenticates OpenAI-compatible API calls via the Authorization Bearer header."""

    def authenticate(self, request):
        auth = request.META.get('HTTP_AUTHORIZATION')
        # Not authenticated: the header is missing entirely.
        # Bug fix: this check must come BEFORE stripping the 'Bearer ' prefix —
        # the original called auth.replace(...) first, so a missing header raised
        # AttributeError on None instead of the intended 1003 failure.
        if auth is None:
            raise AppAuthenticationFailed(1003, '未登录,请先登录')
        auth = auth.replace('Bearer ', '')
        try:
            token_details = TokenDetails(auth)
            # Delegate to the first handler that recognizes this token type.
            for handle in handles:
                if handle.support(request, auth, token_details.get_token_details):
                    return handle.handle(request, auth, token_details.get_token_details)
            raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
        except Exception as e:
            # Propagate domain-specific failures; everything else becomes a generic 1002.
            # (Removed a bare traceback.format_exc() whose result was discarded.)
            if isinstance(e, (AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed)):
                raise e
            raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
class TokenAuth(TokenAuthentication):
# 重新 authenticate 方法,自定义认证规则
def authenticate(self, request):

View File

@ -73,3 +73,11 @@ class AppChatNumOutOfBoundsFailed(AppApiException):
def __init__(self, code, message):
self.code = code
self.message = message
class ChatException(AppApiException):
    """Chat-flow API error; rendered with HTTP status 500."""
    # DRF-style status code used when this exception is turned into a response.
    status_code = 500

    def __init__(self, code, message):
        # Business error code and human-readable message.
        # NOTE(review): does not call super().__init__ — presumably mirrors
        # AppApiException's own behavior; confirm against its definition.
        self.code = code
        self.message = message

View File

@ -0,0 +1,27 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file base_to_response.py
@date2024/9/6 16:04
@desc:
"""
from abc import ABC, abstractmethod
from rest_framework import status
class BaseToResponse(ABC):
    """Strategy interface for turning chat results into HTTP responses.

    Implementations decide the wire format (system-native vs OpenAI-compatible)
    for both one-shot ("block") answers and streamed chunks.
    """

    @abstractmethod
    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
                          _status=status.HTTP_200_OK):
        """Build the complete (non-streaming) response for one chat turn."""
        pass

    @abstractmethod
    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
        """Build a single server-sent-event chunk for a streamed answer."""
        pass

    @staticmethod
    def format_stream_chunk(response_str):
        # Wrap a serialized payload in the SSE framing: "data: <payload>\n\n".
        return f'data: {response_str}\n\n'

View File

@ -0,0 +1,42 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file openai_to_response.py
@date2024/9/6 16:08
@desc:
"""
import datetime
from django.http import JsonResponse
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessage, ChatCompletion
from openai.types.chat.chat_completion import Choice as BlockChoice
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from rest_framework import status
from common.handle.base_to_response import BaseToResponse
class OpenaiToResponse(BaseToResponse):
    """Formats chat results as OpenAI-compatible completion responses."""

    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
                          _status=status.HTTP_200_OK):
        """Return a JsonResponse shaped like an OpenAI ChatCompletion object."""
        data = ChatCompletion(
            id=chat_record_id,
            choices=[BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
                                 message=ChatCompletionMessage(role='assistant', content=content))],
            # Bug fix: 'created' is a Unix timestamp in the OpenAI schema;
            # the original used .second, which only yields 0-59.
            created=int(datetime.datetime.now().timestamp()),
            model='', object='chat.completion',
            usage=CompletionUsage(completion_tokens=completion_tokens,
                                  prompt_tokens=prompt_tokens,
                                  total_tokens=completion_tokens + prompt_tokens)
        ).dict()
        # Bug fix: pass the integer status code — the original passed the
        # rest_framework.status MODULE as the HTTP status.
        return JsonResponse(data=data, status=_status)

    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
        """Return one SSE chunk shaped like an OpenAI ChatCompletionChunk."""
        chunk = ChatCompletionChunk(
            id=chat_record_id, model='', object='chat.completion.chunk',
            created=int(datetime.datetime.now().timestamp()),
            choices=[Choice(delta=ChoiceDelta(content=content, chat_id=chat_id),
                            finish_reason='stop' if is_end else None,
                            index=0)],
            usage=CompletionUsage(completion_tokens=completion_tokens,
                                  prompt_tokens=prompt_tokens,
                                  total_tokens=completion_tokens + prompt_tokens)).json()
        return super().format_stream_chunk(chunk)

View File

@ -0,0 +1,26 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file system_to_response.py
@date2024/9/6 18:03
@desc:
"""
import json
from rest_framework import status
from common.handle.base_to_response import BaseToResponse
from common.response import result
class SystemToResponse(BaseToResponse):
    """Default (MaxKB-native) response formatter for chat results."""

    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
                          _status=status.HTTP_200_OK):
        """Return the system-native one-shot response envelope."""
        payload = {'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                   'content': content, 'is_end': is_end}
        return result.success(payload, response_status=_status, code=_status)

    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
        """Return one system-native SSE chunk: JSON payload wrapped in SSE framing."""
        payload = {'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True,
                   'content': content, 'is_end': is_end}
        return super().format_stream_chunk(json.dumps(payload))

View File

@ -50,6 +50,8 @@ class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
res = requests.post(
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents',
@ -85,6 +87,8 @@ class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor):
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
Sequence[Document]:
if documents is None or len(documents) == 0:
return []
with torch.no_grad():
inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True,
truncation=True, return_tensors='pt', max_length=512)

View File

@ -6,9 +6,11 @@
@date2024/4/18 15:28
@desc:
"""
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.runnables import RunnableConfig
from langchain_openai.chat_models import ChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage