feat: support embedding models

shaohuzhang1 2024-07-17 17:01:57 +08:00
parent 9b81b89975
commit bd4303aee7
14 changed files with 223 additions and 98 deletions

View File

@@ -13,22 +13,45 @@ from django.db.models import QuerySet
from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel
from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore, EmbeddingModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Paragraph
from dataset.models import Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id):
    model = QuerySet(Model).filter(id=_id).first()
    if model is None:
        raise Exception("Model does not exist")
    return model


def get_embedding_id(dataset_id_list):
    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
        raise Exception("The selected datasets use different embedding models")
    if len(dataset_list) == 0:
        raise Exception("Dataset configuration error, please reconfigure the dataset")
    return dataset_list[0].embedding_mode_id
class BaseSearchDatasetStep(ISearchDatasetStep):
def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None,
search_mode: str = None,
**kwargs) -> List[ParagraphPipelineModel]:
if len(dataset_id_list) == 0:
return []
exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text
embedding_model = EmbeddingModel.get_embedding_model()
model_id = get_embedding_id(dataset_id_list)
model = get_model_by_id(model_id)
self.context['model_name'] = model.name
embedding_model = EmbeddingModelManage.get_model(model_id, lambda _id: get_model(model))
embedding_value = embedding_model.embed_query(exec_problem_text)
vector = VectorStore.get_embedding_vector()
embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list,
@@ -101,7 +124,7 @@ class BaseSearchDatasetStep(ISearchDatasetStep):
'run_time': self.context['run_time'],
'problem_text': step_args.get(
'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'),
'model_name': EmbeddingModel.get_embedding_model().model_name,
'model_name': self.context.get('model_name'),
'message_tokens': 0,
'answer_tokens': 0,
'cost': 0

View File

@@ -13,20 +13,47 @@ from django.db.models import QuerySet
from application.flow.i_step_node import NodeResult
from application.flow.step_node.search_dataset_node.i_search_dataset_node import ISearchDatasetStepNode
from common.config.embedding_config import EmbeddingModel, VectorStore
from common.config.embedding_config import VectorStore, EmbeddingModelManage
from common.db.search import native_search
from common.util.file_util import get_file_content
from dataset.models import Document, Paragraph
from dataset.models import Document, Paragraph, DataSet
from embedding.models import SearchMode
from setting.models import Model
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
def get_model_by_id(_id):
    model = QuerySet(Model).filter(id=_id).first()
    if model is None:
        raise Exception("Model does not exist")
    return get_model(model)


def get_embedding_id(dataset_id_list):
    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
        raise Exception("The selected datasets use different embedding models")
    if len(dataset_list) == 0:
        raise Exception("Dataset configuration error, please reconfigure the dataset")
    return dataset_list[0].embedding_mode_id
def get_none_result(question):
return NodeResult(
{'paragraph_list': [], 'is_hit_handling_method': [], 'question': question, 'data': '',
'directly_return': ''}, {})
class BaseSearchDatasetNode(ISearchDatasetStepNode):
def execute(self, dataset_id_list, dataset_setting, question,
exclude_paragraph_id_list=None,
**kwargs) -> NodeResult:
self.context['question'] = question
embedding_model = EmbeddingModel.get_embedding_model()
if len(dataset_id_list) == 0:
return get_none_result(question)
model_id = get_embedding_id(dataset_id_list)
embedding_model = EmbeddingModelManage.get_model(model_id, get_model_by_id)
embedding_value = embedding_model.embed_query(question)
vector = VectorStore.get_embedding_vector()
exclude_document_id_list = [str(document.id) for document in
@@ -37,7 +64,7 @@ class BaseSearchDatasetNode(ISearchDatasetStepNode):
exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
if embedding_list is None:
return NodeResult({'paragraph_list': [], 'is_hit_handling_method': []}, {})
return get_none_result(question)
paragraph_list = self.list_paragraph(embedding_list, vector)
result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list]
return NodeResult({'paragraph_list': result,

View File

@@ -26,7 +26,7 @@ from rest_framework import serializers
from application.flow.workflow_manage import Flow
from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
@@ -36,7 +36,7 @@ from common.util.common import valid_license
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from setting.models import AuthOperate
from setting.models.model_management import Model
@@ -415,12 +415,13 @@ class ApplicationSerializer(serializers.Serializer):
QuerySet(Document).filter(
dataset_id__in=dataset_id_list,
is_active=False)]
model = get_embedding_model_by_dataset_id_list(dataset_id_list)
        # Vector store retrieval
hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

View File

@@ -41,3 +41,7 @@ class MemCache(LocMemCache):
delete_keys.append(key)
for key in delete_keys:
self._delete(key)
    def clear_timeout_data(self):
        # Snapshot the keys first: get() evicts an expired entry, which would
        # otherwise mutate the dict while we iterate over it.
        for key in list(self._cache.keys()):
            self.get(key)
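For context: Django's LocMemCache evicts an expired entry the moment it is read, which is the behaviour clear_timeout_data relies on. A minimal standalone sketch of that read-through eviction (no Django settings required):

import time

from django.core.cache.backends.locmem import LocMemCache

cache = LocMemCache('demo', {})
cache.set('model-1', object(), timeout=1)
time.sleep(1.1)
# Reading the expired key returns None and drops the entry from the store.
assert cache.get('model-1') is None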

View File

@@ -6,33 +6,31 @@
@date: 2023/10/23 16:03
@desc:
"""
from langchain_huggingface.embeddings import HuggingFaceEmbeddings
import time
from smartdoc.const import CONFIG
from common.cache.mem_cache import MemCache
class EmbeddingModel:
instance = None
class EmbeddingModelManage:
cache = MemCache('model', {})
up_clear_time = time.time()
@staticmethod
def get_embedding_model():
"""
        Get the embedding model
:return:
"""
if EmbeddingModel.instance is None:
model_name = CONFIG.get('EMBEDDING_MODEL_NAME')
cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH')
device = CONFIG.get('EMBEDDING_DEVICE')
encode_kwargs = {'normalize_embeddings': True}
e = HuggingFaceEmbeddings(
model_name=model_name,
cache_folder=cache_folder,
model_kwargs={'device': device},
encode_kwargs=encode_kwargs,
)
EmbeddingModel.instance = e
return EmbeddingModel.instance
def get_model(_id, get_model):
model_instance = EmbeddingModelManage.cache.get(_id)
if model_instance is None:
model_instance = get_model(_id)
EmbeddingModelManage.cache.set(_id, model_instance, timeout=60 * 30)
return model_instance
        # Renew the TTL on every cache hit
EmbeddingModelManage.cache.touch(_id, timeout=60 * 30)
EmbeddingModelManage.clear_timeout_cache()
return model_instance
    @staticmethod
    def clear_timeout_cache():
        if time.time() - EmbeddingModelManage.up_clear_time > 60:
            EmbeddingModelManage.cache.clear_timeout_data()
            # Reset the marker, otherwise the full sweep runs on every call.
            EmbeddingModelManage.up_clear_time = time.time()
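For reviewers, the intended call pattern as a minimal sketch (load_embedding_model is a hypothetical loader; the callback is only invoked on a cache miss):

def load_embedding_model(_id):
    # Hypothetical loader: fetch the setting.models.Model row and build the
    # runtime instance, mirroring get_model_by_id in the search nodes.
    model = QuerySet(Model).filter(id=_id).first()
    return get_model(model)

embedding_model = EmbeddingModelManage.get_model(str(model_id), load_embedding_model)
embedding_model.embed_query('hello')  # cached for 30 minutes, TTL renewed per hit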
class VectorStore:

View File

@@ -15,8 +15,9 @@ from typing import List
import django.db.models
from blinker import signal
from django.db.models import QuerySet
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.db.search import native_search, get_dynamics_model
from common.event.common import poxy, embedding_poxy
from common.util.file_util import get_file_content
@@ -89,11 +90,11 @@ class ListenerManagement:
@staticmethod
@embedding_poxy
def embedding_by_paragraph(paragraph_id):
def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings):
"""
向量化段落 根据段落id
:param paragraph_id: 段落id
:return: None
@param paragraph_id: 段落id
@param embedding_model: 向量模型
"""
max_kb.info(f"开始--->向量化段落:{paragraph_id}")
status = Status.success
@ -107,7 +108,7 @@ class ListenerManagement:
# 删除段落
VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@@ -117,10 +118,11 @@ class ListenerManagement:
@staticmethod
@embedding_poxy
def embedding_by_document(document_id):
def embedding_by_document(document_id, embedding_model: Embeddings):
"""
        Embed a document
        :param document_id: document id
        @param document_id: document id
        @param embedding_model: embedding model
        :return: None
"""
max_kb.info(f"开始--->向量化文档:{document_id}")
@ -138,7 +140,7 @@ class ListenerManagement:
# 删除文档向量数据
VectorStore.get_embedding_vector().delete_by_document_id(document_id)
# 批量向量化
VectorStore.get_embedding_vector().batch_save(data_list)
VectorStore.get_embedding_vector().batch_save(data_list, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}')
status = Status.error
@@ -151,10 +153,11 @@ class ListenerManagement:
@staticmethod
@embedding_poxy
def embedding_by_dataset(dataset_id):
def embedding_by_dataset(dataset_id, embedding_model: Embeddings):
"""
        Embed a dataset (knowledge base)
        :param dataset_id: dataset id
        @param dataset_id: dataset id
        @param embedding_model: embedding model
:return: None
"""
max_kb.info(f"开始--->向量化数据集:{dataset_id}")
@ -162,7 +165,7 @@ class ListenerManagement:
document_list = QuerySet(Document).filter(dataset_id=dataset_id)
max_kb.info(f"数据集文档:{[d.name for d in document_list]}")
for document in document_list:
ListenerManagement.embedding_by_document(document.id)
ListenerManagement.embedding_by_document(document.id, embedding_model)
except Exception as e:
max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
finally:
@@ -245,11 +248,6 @@ class ListenerManagement:
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
@staticmethod
@poxy
def init_embedding_model(ages):
EmbeddingModel.get_embedding_model()
def run(self):
        # Add vectors by problem id
ListenerManagement.embedding_by_problem_signal.connect(self.embedding_by_problem)
@@ -276,8 +274,7 @@ class ListenerManagement:
        ListenerManagement.disable_embedding_by_paragraph_signal.connect(self.disable_embedding_by_paragraph)
        # Enable paragraph vectors
        ListenerManagement.enable_embedding_by_paragraph_signal.connect(self.enable_embedding_by_paragraph)
        # Initialize the embedding model
        ListenerManagement.init_embedding_model_signal.connect(self.init_embedding_model)
        # Sync web-site datasets
        ListenerManagement.sync_web_dataset_signal.connect(self.sync_web_dataset)
        # Sync web-site documents

View File

@@ -1,4 +1,4 @@
# Generated by Django 4.2.13 on 2024-07-15 15:56
# Generated by Django 4.2.13 on 2024-07-17 13:56
import dataset.models.data_set
from django.db import migrations, models
@@ -15,7 +15,7 @@ class Migration(migrations.Migration):
operations = [
migrations.AddField(
model_name='dataset',
name='embedding_mode_id',
name='embedding_mode',
            field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='Embedding model'),
),
]

View File

@@ -49,8 +49,8 @@ class DataSet(AppModelMixin):
    user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="Owner")
    type = models.CharField(verbose_name='Type', max_length=1, choices=Type.choices,
                            default=Type.base)
    embedding_mode_id = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="Embedding model",
                                          default=default_model)
    embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="Embedding model",
                                       default=default_model)
    meta = models.JSONField(verbose_name="Metadata", default=dict)
class Meta:

View File

@@ -14,6 +14,7 @@ from django.db.models import QuerySet
from drf_yasg import openapi
from rest_framework import serializers
from common.config.embedding_config import EmbeddingModelManage
from common.db.search import native_search
from common.db.sql_execute import update_execute
from common.exception.app_exception import AppApiException
@@ -21,7 +22,8 @@ from common.mixins.api_mixin import ApiMixin
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.fork import Fork
from dataset.models import Paragraph, Problem, ProblemParagraphMapping
from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet
from setting.models_provider import get_model
from smartdoc.conf import PROJECT_DIR
@@ -130,3 +132,18 @@ class ProblemParagraphManage:
result = [problem_model for problem_model, is_create in problem_content_dict.values() if
is_create], problem_paragraph_mapping_list
return result
def get_embedding_model_by_dataset_id_list(dataset_id_list: List):
    dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list)
    if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1:
        raise Exception("The selected datasets use different embedding models")
    if len(dataset_list) == 0:
        raise Exception("Dataset configuration error, please reconfigure the dataset")
    return EmbeddingModelManage.get_model(str(dataset_list[0].id),
                                          lambda _id: get_model(dataset_list[0].embedding_mode))


def get_embedding_model_by_dataset_id(dataset_id: str):
    # .first() is required here: filter() returns a QuerySet, not a DataSet row.
    dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first()
    return EmbeddingModelManage.get_model(str(dataset.id), lambda _id: get_model(dataset.embedding_mode))
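A hedged usage sketch (dataset ids are placeholders): both helpers hand back a cached, ready-to-use Embeddings instance.

# Placeholder ids; the datasets must share one embedding model, otherwise
# get_embedding_model_by_dataset_id_list raises.
model = get_embedding_model_by_dataset_id_list(['<dataset-id-1>', '<dataset-id-2>'])
query_vector = model.embed_query('What is MaxKB?')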

View File

@@ -15,7 +15,6 @@ from functools import reduce
from typing import Dict, List
from urllib.parse import urlparse
from django.conf import settings
from django.contrib.postgres.fields import ArrayField
from django.core import validators
from django.db import transaction, models
@@ -25,7 +24,7 @@ from drf_yasg import openapi
from rest_framework import serializers
from application.models import ApplicationDatasetMapping
from common.config.embedding_config import VectorStore, EmbeddingModel
from common.config.embedding_config import VectorStore
from common.db.search import get_dynamics_model, native_page_search, native_search
from common.db.sql_execute import select_list
from common.event import ListenerManagement, SyncWebDatasetArgs
@@ -37,7 +36,8 @@ from common.util.file_util import get_file_content
from common.util.fork import ChildLink, Fork
from common.util.split_model import get_split_model
from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage
from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \
get_embedding_model_by_dataset_id
from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer
from embedding.models import SearchMode
from setting.models import AuthOperate
@@ -359,8 +359,9 @@ class DataSetSerializers(serializers.ModelSerializer):
@staticmethod
def post_embedding_dataset(document_list, dataset_id):
model = get_embedding_model_by_dataset_id(dataset_id)
        # Send the embedding event
ListenerManagement.embedding_by_dataset_signal.send(dataset_id)
ListenerManagement.embedding_by_dataset_signal.send(dataset_id, embedding_model=model)
return document_list
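blinker forwards extra keyword arguments from send() to every receiver, which is how the embedding_model kwarg above reaches ListenerManagement.embedding_by_dataset. A self-contained sketch of the mechanism (names are illustrative, not project code):

from blinker import signal

demo_signal = signal('embedding_by_dataset_demo')

def receiver(dataset_id, embedding_model=None):
    # send()'s first positional argument arrives as the sender; keyword
    # arguments are forwarded unchanged.
    print(dataset_id, embedding_model)

demo_signal.connect(receiver)
demo_signal.send('dataset-id', embedding_model='model-instance')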
def save_qa(self, instance: Dict, with_valid=True):
@@ -565,12 +566,13 @@ class DataSetSerializers(serializers.ModelSerializer):
QuerySet(Document).filter(
dataset_id=self.data.get('id'),
is_active=False)]
model = get_embedding_model_by_dataset_id(self.data.get('id'))
        # Vector store retrieval
hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list,
self.data.get('top_number'),
self.data.get('similarity'),
SearchMode(self.data.get('search_mode')),
EmbeddingModel.get_embedding_model())
model)
hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {})
p_list = list_paragraph([h.get('paragraph_id') for h in hit_list])
return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'),

View File

@@ -10,9 +10,8 @@ import threading
from abc import ABC, abstractmethod
from typing import List, Dict
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.util.common import sub_array
from embedding.models import SourceType, SearchMode
@@ -51,7 +50,7 @@ class BaseVectorStore(ABC):
def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding=None):
embedding: Embeddings):
"""
        Insert vector data
        :param source_id: source id
@@ -64,13 +63,10 @@
        :param paragraph_id: paragraph id
        :return: bool
"""
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
def batch_save(self, data_list: List[Dict], embedding=None):
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
        # Acquire the lock
lock.acquire()
try:
@@ -80,8 +76,6 @@ class BaseVectorStore(ABC):
            :param embedding: embedding handler
:return: bool
"""
if embedding is None:
embedding = EmbeddingModel.get_embedding_model()
self.save_pre_handler()
result = sub_array(data_list)
for child_array in result:
@@ -94,17 +88,17 @@
@abstractmethod
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
pass
@abstractmethod
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
pass
def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str],
exclude_paragraph_list: list[str],
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
embedding_query = embedding.embed_query(query_text)
@@ -123,7 +117,7 @@
def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
pass
@abstractmethod
@@ -142,14 +136,6 @@
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
pass
@abstractmethod
def embed_documents(self, text_list: List[str]):
pass
@abstractmethod
def embed_query(self, text: str):
pass
@abstractmethod
def delete_by_dataset_id(self, dataset_id: str):
pass

View File

@@ -13,9 +13,8 @@ from abc import ABC, abstractmethod
from typing import Dict, List
from django.db.models import QuerySet
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
from common.config.embedding_config import EmbeddingModel
from common.db.search import generate_sql_by_query_dict
from common.db.sql_execute import select_list
from common.util.file_util import get_file_content
@@ -33,14 +32,6 @@ class PGVector(BaseVectorStore):
def update_by_source_ids(self, source_ids: List[str], instance: Dict):
QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance)
def embed_documents(self, text_list: List[str]):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_documents(text_list)
def embed_query(self, text: str):
embedding = EmbeddingModel.get_embedding_model()
return embedding.embed_query(text)
def vector_is_create(self) -> bool:
        # Already created at project startup; no need to create it again
return True
@@ -50,7 +41,7 @@
def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str,
is_active: bool,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
text_embedding = embedding.embed_query(text)
embedding = Embedding(id=uuid.uuid1(),
dataset_id=dataset_id,
@@ -64,7 +55,7 @@
embedding.save()
return True
def _batch_save(self, text_list: List[Dict], embedding: HuggingFaceEmbeddings):
def _batch_save(self, text_list: List[Dict], embedding: Embeddings):
texts = [row.get('text') for row in text_list]
embeddings = embedding.embed_documents(texts)
embedding_list = [Embedding(id=uuid.uuid1(),
@@ -83,7 +74,7 @@
def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int,
similarity: float,
search_mode: SearchMode,
embedding: HuggingFaceEmbeddings):
embedding: Embeddings):
if dataset_id_list is None or len(dataset_id_list) == 0:
return []
exclude_dict = {}
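With the global-model fallback removed, every save path must now supply an Embeddings instance explicitly. A hedged test sketch (FakeEmbeddings is an illustrative stub, not project code; a configured Django environment is assumed for the actual save):

from typing import List

from langchain_core.embeddings import Embeddings

class FakeEmbeddings(Embeddings):
    """Deterministic stub for exercising batch_save in tests."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0] * 768 for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return [0.0] * 768

# Each row needs at least the 'text' key plus the id fields _batch_save copies.
VectorStore.get_embedding_vector().batch_save(data_list, FakeEmbeddings())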

View File

@@ -6,3 +6,85 @@
@date: 2023/10/31 17:16
@desc:
"""
import json
from typing import Dict
from common.util.rsa_util import rsa_long_decrypt
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
def get_model_(provider, model_type, model_name, credential):
    """
    Get a model instance
    @param provider: provider
    @param model_type: model type
    @param model_name: model name
    @param credential: credential data
    @return: model instance
    """
    model = get_provider(provider).get_model(model_type, model_name,
                                             json.loads(
                                                 rsa_long_decrypt(credential)),
                                             streaming=True)
    return model


def get_model(model):
    """
    Get a model instance
    @param model: database Model row
    @return: model instance
    """
    return get_model_(model.provider, model.model_type, model.model_name, model.credential)


def get_provider(provider):
    """
    Get a provider instance
    @param provider: provider key string
    @return: provider instance
    """
    return ModelProvideConstants[provider].value


def get_model_list(provider, model_type):
    """
    Get the model list
    @param provider: provider key string
    @param model_type: model type
    @return: model list
    """
    return get_provider(provider).get_model_list(model_type)


def get_model_credential(provider, model_type, model_name):
    """
    Get the model credential handler
    @param provider: provider key string
    @param model_type: model type
    @param model_name: model name
    @return: credential instance
    """
    return get_provider(provider).get_model_credential(model_type, model_name)


def get_model_type_list(provider):
    """
    Get the model type list
    @param provider: provider key string
    @return: model type list
    """
    return get_provider(provider).get_model_type_list()


def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False):
    """
    Validate model credential parameters
    @param provider: provider key string
    @param model_type: model type
    @param model_name: model name
    @param model_credential: model credential data
    @param raise_exception: whether to raise on failure
    @return: True|False
    """
    return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception)
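These helpers compose into one resolution chain: database row, provider, decrypted credential, runtime client. A hedged sketch of typical use (model_id is a placeholder; QuerySet and Model imports are assumed, as at the call sites above):

# Resolve a runtime client starting from a database row.
model_row = QuerySet(Model).filter(id=model_id).first()
embedding_model = get_model(model_row)  # decrypts the credential, builds the client

# Provider metadata can also be queried directly:
credential_form = get_model_credential(model_row.provider, model_row.model_type,
                                       model_row.model_name)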

View File

@@ -104,9 +104,6 @@ CACHES = {
"token_cache": {
'BACKEND': 'common.cache.file_cache.FileCache',
        'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache")  # folder path
},
"chat_cache": {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
}
}