MaxKB/apps/local_model/serializers/model_apply_serializers.py

161 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: MaxKB
@Author
@file model_apply_serializers.py
@date2024/8/20 20:39
@desc:
"""
import json
import threading
import time
from django.db import connection
from django.db.models import QuerySet
from django.utils.translation import gettext_lazy as _
from langchain_core.documents import Document
from rest_framework import serializers
from local_model.models import Model
from local_model.serializers.rsa_util import rsa_long_decrypt
from models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider
from common.cache.mem_cache import MemCache
_lock = threading.Lock()
locks = {}
class ModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def _get_lock(_id):
lock = locks.get(_id)
if lock is None:
with _lock:
lock = locks.get(_id)
if lock is None:
lock = threading.Lock()
locks[_id] = lock
return lock
@staticmethod
def get_model(_id, get_model):
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
lock = ModelManage._get_lock(_id)
with lock:
model_instance = ModelManage.cache.get(_id)
if model_instance is None:
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
else:
if model_instance.is_cache_model():
ModelManage.cache.touch(_id, timeout=60 * 60 * 8)
else:
model_instance = get_model(_id)
ModelManage.cache.set(_id, model_instance, timeout=60 * 60 * 8)
ModelManage.clear_timeout_cache()
return model_instance
@staticmethod
def clear_timeout_cache():
if time.time() - ModelManage.up_clear_time > 60 * 60:
threading.Thread(target=lambda: ModelManage.cache.clear_timeout_data()).start()
ModelManage.up_clear_time = time.time()
@staticmethod
def delete_key(_id):
if ModelManage.cache.has_key(_id):
ModelManage.cache.delete(_id)
def get_local_model(model, **kwargs):
return LocalModelProvider().get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
model_id=model.id,
streaming=True, **kwargs)
def get_embedding_model(model_id):
model = QuerySet(Model).filter(id=model_id).first()
# 手动关闭数据库连接
connection.close()
embedding_model = ModelManage.get_model(model_id,
lambda _id: get_local_model(model, use_local=True))
return embedding_model
class EmbedDocuments(serializers.Serializer):
texts = serializers.ListField(required=True, child=serializers.CharField(required=True,
label=_('vector text')),
label=_('vector text list')),
class EmbedQuery(serializers.Serializer):
text = serializers.CharField(required=True, label=_('vector text'))
class CompressDocument(serializers.Serializer):
page_content = serializers.CharField(required=True, label=_('text'))
metadata = serializers.DictField(required=False, label=_('metadata'))
class CompressDocuments(serializers.Serializer):
documents = CompressDocument(required=True, many=True)
query = serializers.CharField(required=True, label=_('query'))
class ValidateModelSerializers(serializers.Serializer):
model_name = serializers.CharField(required=True, label=_('model_name'))
model_type = serializers.CharField(required=True, label=_('model_type'))
model_credential = serializers.DictField(required=True, label="credential")
def validate_model(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
self.data.get('model_credential'), model_params={},
raise_exception=True)
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, label=_('model id'))
def embed_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_documents(instance.getlist('texts'))
def embed_query(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
EmbedQuery(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return model.embed_query(instance.get('text'))
def compress_documents(self, instance, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
CompressDocuments(data=instance).is_valid(raise_exception=True)
model = get_embedding_model(self.data.get('model_id'))
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
instance.get('documents')], instance.get('query'))]
def unload(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
ModelManage.delete_key(self.data.get('model_id'))
return True