MaxKB/apps/application/serializers/chat_serializers.py
2023-11-27 14:22:12 +08:00

284 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file chat_serializers.py
@date2023/11/14 9:59
@desc:
"""
import datetime
import json
import os
import uuid
from typing import Dict
from django.core.cache import cache
from django.db import transaction
from django.db.models import QuerySet
from rest_framework import serializers
from application.models import Chat, Application, ApplicationDatasetMapping, VoteChoices, ChatRecord
from application.serializers.application_serializers import ModelDatasetAssociation
from application.serializers.chat_message_serializers import ChatInfo
from common.db.search import native_search, native_page_search, page_search
from common.exception.app_exception import AppApiException
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from dataset.models import Document, Problem, Paragraph
from embedding.models import SourceType, Embedding
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
chat_cache = cache
class ChatSerializers(serializers.Serializer):
class Query(serializers.Serializer):
abstract = serializers.CharField(required=False)
history_day = serializers.IntegerField(required=True)
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
def get_end_time(self):
history_day = self.data.get('history_day')
return datetime.datetime.now() - datetime.timedelta(days=history_day)
def get_query_set(self):
end_time = self.get_end_time()
return QuerySet(Chat).filter(application_id=self.data.get("application_id"), create_time__gte=end_time)
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return native_search(self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
with_table_name=True)
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
with_table_name=True)
class OpenChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
application_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
user_id = self.data.get('user_id')
application_id = self.data.get('application_id')
if not QuerySet(Application).filter(id=application_id, user_id=user_id).exists():
raise AppApiException(500, '应用不存在')
def open(self):
self.is_valid(raise_exception=True)
application_id = self.data.get('application_id')
application = QuerySet(Application).get(id=application_id)
model = application.model
dataset_id_list = [str(row.dataset_id) for row in
QuerySet(ApplicationDatasetMapping).filter(
application_id=application_id)]
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
chat_id = str(uuid.uuid1())
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, application_id, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
application.dialogue_number), timeout=60 * 30)
return chat_id
class OpenTempChat(serializers.Serializer):
user_id = serializers.UUIDField(required=True)
model_id = serializers.UUIDField(required=True)
multiple_rounds_dialogue = serializers.BooleanField(required=True)
dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True))
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
ModelDatasetAssociation(
data={'user_id': self.data.get('user_id'), 'model_id': self.data.get('model_id'),
'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
def open(self):
self.is_valid(raise_exception=True)
chat_id = str(uuid.uuid1())
model = QuerySet(Model).get(user_id=self.data.get('user_id'), id=self.data.get('model_id'))
dataset_id_list = self.data.get('dataset_id_list')
chat_model = ModelProvideConstants[model.provider].value.get_model(model.model_type, model.model_name,
json.loads(
decrypt(model.credential)),
streaming=True)
chat_cache.set(chat_id,
ChatInfo(chat_id, model, chat_model, None, dataset_id_list,
[str(document.id) for document in
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)],
3 if self.data.get('multiple_rounds_dialogue') else 1), timeout=60 * 30)
return chat_id
def vote_exec(source_type: SourceType, source_id: str, field: str, post_handler):
if source_type == SourceType.PROBLEM:
problem = QuerySet(Problem).get(id=source_id)
if problem is not None:
problem.__setattr__(field, post_handler(problem))
problem.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, problem.__getattribute__(field))
embedding.save()
if source_type == SourceType.PARAGRAPH:
paragraph = QuerySet(Paragraph).get(id=source_id)
if paragraph is not None:
paragraph.__setattr__(field, post_handler(paragraph))
paragraph.save()
embedding = QuerySet(Embedding).get(source_id=source_id, source_type=source_type)
embedding.__setattr__(field, paragraph.__getattribute__(field))
embedding.save()
class ChatRecordSerializerModel(serializers.ModelSerializer):
class Meta:
model = ChatRecord
fields = "__all__"
class ChatRecordSerializer(serializers.Serializer):
class Query(serializers.Serializer):
application_id = serializers.UUIDField(required=True)
chat_id = serializers.UUIDField(required=True)
def list(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return [ChatRecordSerializerModel(chat_record).data for chat_record in
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id'))]
def page(self, current_page: int, page_size: int, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
return page_search(current_page, page_size,
QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by("index"),
post_records_handler=lambda chat_record: ChatRecordSerializerModel(chat_record).data)
class Vote(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
vote_status = serializers.ChoiceField(choices=VoteChoices.choices)
@transaction.atomic
def vote(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
if not try_lock(self.data.get('chat_record_id')):
raise AppApiException(500, "正在对当前会话纪要进行投票中,请勿重复发送请求")
try:
chat_record_details_model = QuerySet(ChatRecord).get(id=self.data.get('chat_record_id'),
chat_id=self.data.get('chat_id'))
if chat_record_details_model is None:
raise AppApiException(500, "不存在的对话 chat_record_id")
vote_status = self.data.get("vote_status")
if chat_record_details_model.vote_status == VoteChoices.UN_VOTE:
if vote_status == VoteChoices.STAR:
# 点赞
chat_record_details_model.vote_status = VoteChoices.STAR
# 点赞数量 +1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num',
lambda r: r.star_num + 1)
if vote_status == VoteChoices.TRAMPLE:
# 点踩
chat_record_details_model.vote_status = VoteChoices.TRAMPLE
# 点踩数量+1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num',
lambda r: r.trample_num + 1)
chat_record_details_model.save()
else:
if vote_status == VoteChoices.UN_VOTE:
# 取消点赞
chat_record_details_model.vote_status = VoteChoices.UN_VOTE
chat_record_details_model.save()
if chat_record_details_model.vote_status == VoteChoices.STAR:
# 点赞数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'star_num', lambda r: r.star_num - 1)
if chat_record_details_model.vote_status == VoteChoices.TRAMPLE:
# 点踩数量 -1
vote_exec(chat_record_details_model.source_type, chat_record_details_model.source_id,
'trample_num', lambda r: r.trample_num - 1)
else:
raise AppApiException(500, "已经投票过,请先取消后再进行投票")
finally:
un_lock(self.data.get('chat_record_id'))
return True
class ImproveSerializer(serializers.Serializer):
title = serializers.CharField(required=False)
content = serializers.CharField(required=True)
class Improve(serializers.Serializer):
chat_id = serializers.UUIDField(required=True)
chat_record_id = serializers.UUIDField(required=True)
dataset_id = serializers.UUIDField(required=True)
document_id = serializers.UUIDField(required=True)
def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
if not QuerySet(Document).filter(id=self.data.get('document_id'),
dataset_id=self.data.get('dataset_id')).exists():
raise AppApiException(500, "文档id不正确")
@transaction.atomic
def improve(self, instance: Dict, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ChatRecordSerializer.ImproveSerializer(data=instance).is_valid(raise_exception=True)
chat_record_id = self.data.get('chat_record_id')
chat_id = self.data.get('chat_id')
chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
if chat_record is None:
raise AppApiException(500, '不存在的对话记录')
document_id = self.data.get("document_id")
dataset_id = self.data.get("dataset_id")
paragraph = Paragraph(id=uuid.uuid1(),
document_id=document_id,
content=instance.get("content"),
dataset_id=dataset_id,
title=instance.get("title") if 'title' in instance else '')
problem = Problem(id=uuid.uuid1(), content=chat_record.problem_text, paragraph_id=paragraph.id,
document_id=document_id, dataset_id=dataset_id)
# 插入问题
problem.save()
# 插入段落
paragraph.save()
chat_record.improve_problem_id_list.append(problem.id)
# 添加标注
chat_record.save()
return True