fix: ai对话未计算tokens #247

This commit is contained in:
shaohuzhang1 2024-04-25 14:05:59 +08:00 committed by GitHub
parent 25aa4cd2b7
commit 72f43cff51
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 65 additions and 12 deletions

View File

@ -154,6 +154,7 @@ class BaseChatStep(IChatStep):
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
else:
chat_result = chat_model.stream(message_list)
is_ai_chat = True
@ -187,8 +188,18 @@ class BaseChatStep(IChatStep):
'status') == 'designated_answer':
chat_result = AIMessage(content=no_references_setting.get('value'))
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
if paragraph_list is not None and len(paragraph_list) > 0:
directly_return_chunk_list = [AIMessageChunk(content=paragraph.title + "\n" + paragraph.content)
for paragraph in paragraph_list if
paragraph.hit_handling_method == 'directly_return']
if directly_return_chunk_list is not None and len(directly_return_chunk_list) > 0:
chat_result = iter(directly_return_chunk_list)
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
else:
chat_result = chat_model.invoke(message_list)
is_ai_chat = True
chat_record_id = uuid.uuid1()
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)

View File

@ -0,0 +1,18 @@
# Generated by Django 4.1.13 on 2024-04-25 11:28
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('application', '0003_application_icon'),
]
operations = [
migrations.AddField(
model_name='applicationaccesstoken',
name='show_source',
field=models.BooleanField(default=False, verbose_name='是否显示知识来源'),
),
]

View File

@ -39,6 +39,7 @@ class ApplicationAccessToken(AppModelMixin):
white_list = ArrayField(verbose_name="白名单列表",
base_field=models.CharField(max_length=128, blank=True)
, default=list)
show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源")
class Meta:
db_table = "application_access_token"

View File

@ -28,7 +28,7 @@ from common.config.embedding_config import VectorStore, EmbeddingModel
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
@ -170,7 +170,9 @@ class ApplicationSerializer(serializers.Serializer):
white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True,
error_messages=ErrMessage.char(
"白名单")),
error_messages=ErrMessage.list("白名单列表"))
error_messages=ErrMessage.list("白名单列表")),
show_source = serializers.BooleanField(required=False,
error_messages=ErrMessage.boolean("是否显示知识来源"))
def edit(self, instance: Dict, with_valid=True):
if with_valid:
@ -190,6 +192,8 @@ class ApplicationSerializer(serializers.Serializer):
application_access_token.white_active = instance.get("white_active")
if 'white_list' in instance and instance.get('white_list') is not None:
application_access_token.white_list = instance.get('white_list')
if 'show_source' in instance and instance.get('show_source') is not None:
application_access_token.show_source = instance.get('show_source')
application_access_token.save()
return self.one(with_valid=False)
@ -210,7 +214,8 @@ class ApplicationSerializer(serializers.Serializer):
"is_active": application_access_token.is_active,
'access_num': application_access_token.access_num,
'white_active': application_access_token.white_active,
'white_list': application_access_token.white_list
'white_list': application_access_token.white_list,
'show_source': application_access_token.show_source
}
class Authentication(serializers.Serializer):
@ -474,8 +479,12 @@ class ApplicationSerializer(serializers.Serializer):
self.is_valid()
application_id = self.data.get("application_id")
application = QuerySet(Application).get(id=application_id)
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first()
if application_access_token is None:
raise AppUnauthorizedFailed(500, "非法用户")
return ApplicationSerializer.Query.reset_application(
ApplicationSerializer.ApplicationModel(application).data)
{**ApplicationSerializer.ApplicationModel(application).data,
'show_source': application_access_token.show_source})
def edit(self, instance: Dict, with_valid=True):
if with_valid:

View File

@ -23,6 +23,7 @@ from django.http import HttpResponse
from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.models.api_key_model import ApplicationAccessToken
from application.serializers.application_serializers import ModelDatasetAssociation, DatasetSettingSerializer, \
ModelSettingSerializer
from application.serializers.chat_message_serializers import ChatInfo
@ -277,17 +278,27 @@ class ChatRecordSerializerModel(serializers.ModelSerializer):
class ChatRecordSerializer(serializers.Serializer):
class Operate(serializers.Serializer):
chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
application_access_token = QuerySet(ApplicationAccessToken).filter(
application_id=self.data.get('application_id')).first()
if application_access_token is None:
raise AppApiException(500, '不存在的应用认证信息')
if not application_access_token.show_source:
raise AppApiException(500, '未开启显示知识来源')
def get_chat_record(self):
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_info: ChatInfo = chat_cache.get(chat_id)
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
if chat_info is not None:
chat_record_list = [chat_record for chat_record in chat_info.chat_record_list if
chat_record.id == uuid.UUID(chat_record_id)]
if chat_record_list is not None and len(chat_record_list):
return chat_record_list[-1]
return QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
def one(self, with_valid=True):

View File

@ -132,6 +132,8 @@ class ApplicationApi(ApiMixin):
'white_list': openapi.Schema(type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表",
description="白名单列表"),
'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源",
description="是否显示知识来源"),
}
)

View File

@ -172,7 +172,8 @@ class ChatView(APIView):
tags=["应用/对话日志"]
)
@has_permissions(
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY,
RoleConstants.APPLICATION_ACCESS_TOKEN],
[lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
dynamic_tag=keywords.get('application_id'))])
)