diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py index 69f6f4a37..1b1822504 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -7,12 +7,10 @@ @desc: """ import os -from typing import Dict, List +from typing import Dict from langchain.chat_models import AzureChatOpenAI -from langchain.chat_models.base import BaseChatModel -from langchain.schema import HumanMessage, BaseMessage -from langchain.schema.language_model import LanguageModelInput +from langchain.schema import HumanMessage from common import froms from common.exception.app_exception import AppApiException diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 0e68a3aa3..fba0b0a66 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -29,6 +29,8 @@ class ModelSerializer(serializers.Serializer): model_name = serializers.CharField(required=False) + provider = serializers.CharField(required=False) + def list(self, with_valid): if with_valid: self.is_valid(raise_exception=True) @@ -42,6 +44,9 @@ class ModelSerializer(serializers.Serializer): query_params['model_type'] = self.data.get('model_type') if self.data.get('model_name') is not None: query_params['model_name'] = self.data.get('model_name') + if self.data.get('provider') is not None: + query_params['provider'] = self.data.get('provider') + return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)] class Create(serializers.Serializer): diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index 38076cfeb..083d37c8b 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -26,7 +26,12 @@ class ModelQueryApi(ApiMixin): openapi.Parameter(name='model_name', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, - description='基础模型名称') + description='基础模型名称'), + openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='供应名称') ] diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 84f32e673..0efbc3aaa 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -48,6 +48,7 @@ class Model(APIView): class Provide(APIView): + authentication_classes = [TokenAuth] class Exec(APIView): authentication_classes = [TokenAuth]