From 128dc0a20184154031d56cf70059f6ad5793b790 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 23 Nov 2023 16:11:57 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BE=9B=E5=BA=94=E5=95=86=E5=88=97?= =?UTF-8?q?=E8=A1=A8=E6=8A=A5=E9=94=99,=E6=A8=A1=E5=9E=8B=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=A0=B9=E6=8D=AE=E4=BE=9B=E5=BA=94=E5=95=86=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../impl/azure_model_provider/azure_model_provider.py | 6 ++---- apps/setting/serializers/provider_serializers.py | 5 +++++ apps/setting/swagger_api/provide_api.py | 7 ++++++- apps/setting/views/model.py | 1 + 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py index 69f6f4a37..1b1822504 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py +++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -7,12 +7,10 @@ @desc: """ import os -from typing import Dict, List +from typing import Dict from langchain.chat_models import AzureChatOpenAI -from langchain.chat_models.base import BaseChatModel -from langchain.schema import HumanMessage, BaseMessage -from langchain.schema.language_model import LanguageModelInput +from langchain.schema import HumanMessage from common import froms from common.exception.app_exception import AppApiException diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index 0e68a3aa3..fba0b0a66 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -29,6 +29,8 @@ class ModelSerializer(serializers.Serializer): model_name = serializers.CharField(required=False) + provider = serializers.CharField(required=False) + def list(self, with_valid): if with_valid: self.is_valid(raise_exception=True) @@ -42,6 +44,9 @@ class ModelSerializer(serializers.Serializer): query_params['model_type'] = self.data.get('model_type') if self.data.get('model_name') is not None: query_params['model_name'] = self.data.get('model_name') + if self.data.get('provider') is not None: + query_params['provider'] = self.data.get('provider') + return [ModelSerializer.model_to_dict(model) for model in model_query_set.filter(**query_params)] class Create(serializers.Serializer): diff --git a/apps/setting/swagger_api/provide_api.py b/apps/setting/swagger_api/provide_api.py index 38076cfeb..083d37c8b 100644 --- a/apps/setting/swagger_api/provide_api.py +++ b/apps/setting/swagger_api/provide_api.py @@ -26,7 +26,12 @@ class ModelQueryApi(ApiMixin): openapi.Parameter(name='model_name', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False, - description='基础模型名称') + description='基础模型名称'), + openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='供应名称') ] diff --git a/apps/setting/views/model.py b/apps/setting/views/model.py index 84f32e673..0efbc3aaa 100644 --- a/apps/setting/views/model.py +++ b/apps/setting/views/model.py @@ -48,6 +48,7 @@ class Model(APIView): class Provide(APIView): + authentication_classes = [TokenAuth] class Exec(APIView): authentication_classes = [TokenAuth]