mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-26 01:33:05 +00:00
fix: ai对话未计算tokens #247
This commit is contained in:
parent
25aa4cd2b7
commit
72f43cff51
|
|
@ -154,6 +154,7 @@ class BaseChatStep(IChatStep):
|
|||
chat_result = iter(directly_return_chunk_list)
|
||||
else:
|
||||
chat_result = chat_model.stream(message_list)
|
||||
is_ai_chat = True
|
||||
else:
|
||||
chat_result = chat_model.stream(message_list)
|
||||
is_ai_chat = True
|
||||
|
|
@ -187,8 +188,18 @@ class BaseChatStep(IChatStep):
|
|||
'status') == 'designated_answer':
|
||||
chat_result = AIMessage(content=no_references_setting.get('value'))
|
||||
else:
|
||||
chat_result = chat_model.invoke(message_list)
|
||||
is_ai_chat = True
|
||||
if paragraph_list is not None and len(paragraph_list) > 0:
|
||||
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
|
||||
for paragraph in paragraph_list if
|
||||
paragraph.hit_handling_method == 'directly_return']
|
||||
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
|
||||
chat_result = iter(directly_return_chunk_list)
|
||||
else:
|
||||
chat_result = chat_model.invoke(message_list)
|
||||
is_ai_chat = True
|
||||
else:
|
||||
chat_result = chat_model.invoke(message_list)
|
||||
is_ai_chat = True
|
||||
chat_record_id = uuid.uuid1()
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.1.13 on 2024-04-25 11:28
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
('application', '0003_application_icon'),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name='applicationaccesstoken',
|
||||
name='show_source',
|
||||
field=models.BooleanField(default=False, verbose_name='是否显示知识来源'),
|
||||
),
|
||||
]
|
||||
|
|
@ -39,6 +39,7 @@ class ApplicationAccessToken(AppModelMixin):
|
|||
white_list = ArrayField(verbose_name="白名单列表",
|
||||
base_field=models.CharField(max_length=128, blank=True)
|
||||
, default=list)
|
||||
show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源")
|
||||
|
||||
class Meta:
|
||||
db_table = "application_access_token"
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
|
|||
from common.constants.authentication_type import AuthenticationType
|
||||
from common.db.search import get_dynamics_model, native_search, native_page_search
|
||||
from common.db.sql_execute import select_list
|
||||
from common.exception.app_exception import AppApiException, NotFound404
|
||||
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
|
||||
from common.field.common import UploadedImageField
|
||||
from common.util.field_message import ErrMessage
|
||||
from common.util.file_util import get_file_content
|
||||
|
|
@ -170,7 +170,9 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True,
|
||||
error_messages=ErrMessage.char(
|
||||
"白名单")),
|
||||
error_messages=ErrMessage.list("白名单列表"))
|
||||
error_messages=ErrMessage.list("白名单列表")),
|
||||
show_source = serializers.BooleanField(required=False,
|
||||
error_messages=ErrMessage.boolean("是否显示知识来源"))
|
||||
|
||||
def edit(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
@ -190,6 +192,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
application_access_token.white_active = instance.get("white_active")
|
||||
if 'white_list' in instance and instance.get('white_list') is not None:
|
||||
application_access_token.white_list = instance.get('white_list')
|
||||
if 'show_source' in instance and instance.get('show_source') is not None:
|
||||
application_access_token.show_source = instance.get('show_source')
|
||||
application_access_token.save()
|
||||
return self.one(with_valid=False)
|
||||
|
||||
|
|
@ -210,7 +214,8 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
"is_active": application_access_token.is_active,
|
||||
'access_num': application_access_token.access_num,
|
||||
'white_active': application_access_token.white_active,
|
||||
'white_list': application_access_token.white_list
|
||||
'white_list': application_access_token.white_list,
|
||||
'show_source': application_access_token.show_source
|
||||
}
|
||||
|
||||
class Authentication(serializers.Serializer):
|
||||
|
|
@ -474,8 +479,12 @@ class ApplicationSerializer(serializers.Serializer):
|
|||
self.is_valid()
|
||||
application_id = self.data.get("application_id")
|
||||
application = QuerySet(Application).get(id=application_id)
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first()
|
||||
if application_access_token is None:
|
||||
raise AppUnauthorizedFailed(500, "非法用户")
|
||||
return ApplicationSerializer.Query.reset_application(
|
||||
ApplicationSerializer.ApplicationModel(application).data)
|
||||
{**ApplicationSerializer.ApplicationModel(application).data,
|
||||
'show_source': application_access_token.show_source})
|
||||
|
||||
def edit(self, instance: Dict, with_valid=True):
|
||||
if with_valid:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from django.http import HttpResponse
|
|||
from rest_framework import serializers
|
||||
|
||||
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
|
||||
from application.models.api_key_model import ApplicationAccessToken
|
||||
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
|
||||
ModelSettingSerializer
|
||||
from application.serializers.chat_message_serializers import ChatInfo
|
||||
|
|
@ -277,17 +278,27 @@ class ChatRecordSerializerModel(serializers.ModelSerializer):
|
|||
class ChatRecordSerializer(serializers.Serializer):
|
||||
class Operate(serializers.Serializer):
|
||||
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
|
||||
|
||||
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
|
||||
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
|
||||
|
||||
def is_valid(self, *, raise_exception=False):
|
||||
super().is_valid(raise_exception=True)
|
||||
application_access_token = QuerySet(ApplicationAccessToken).filter(
|
||||
application_id=self.data.get('application_id')).first()
|
||||
if application_access_token is None:
|
||||
raise AppApiException(500, '不存在的应用认证信息')
|
||||
if not application_access_token.show_source:
|
||||
raise AppApiException(500, '未开启显示知识来源')
|
||||
|
||||
def get_chat_record(self):
|
||||
chat_record_id = self.data.get('chat_record_id')
|
||||
chat_id = self.data.get('chat_id')
|
||||
chat_info: ChatInfo = chat_cache.get(chat_id)
|
||||
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
|
||||
chat_record.id == uuid.UUID(chat_record_id)]
|
||||
if chat_record_list is not None and len(chat_record_list):
|
||||
return chat_record_list[-1]
|
||||
if chat_info is not None:
|
||||
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
|
||||
chat_record.id == uuid.UUID(chat_record_id)]
|
||||
if chat_record_list is not None and len(chat_record_list):
|
||||
return chat_record_list[-1]
|
||||
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
|
||||
|
||||
def one(self, with_valid=True):
|
||||
|
|
|
|||
|
|
@ -132,6 +132,8 @@ class ApplicationApi(ApiMixin):
|
|||
'white_list': openapi.Schema(type=openapi.TYPE_ARRAY,
|
||||
items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表",
|
||||
description="白名单列表"),
|
||||
'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源",
|
||||
description="是否显示知识来源"),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -172,7 +172,8 @@ class ChatView(APIView):
|
|||
tags=["应用/对话日志"]
|
||||
)
|
||||
@has_permissions(
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
|
||||
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
|
||||
RoleConstants.APPLICATION_ACCESS_TOKEN],
|
||||
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
|
||||
dynamic_tag=keywords.get('application_id'))])
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue