diff --git a/apps/setting/models_provider/__init__.py b/apps/setting/models_provider/__init__.py index 7f573ec5e..fb278630a 100644 --- a/apps/setting/models_provider/__init__.py +++ b/apps/setting/models_provider/__init__.py @@ -81,7 +81,7 @@ def get_model_type_list(provider): return get_provider(provider).get_model_type_list() -def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): +def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], model_params, raise_exception=False): """ 校验模型认证参数 @param provider: 供应商字符串 @@ -91,4 +91,4 @@ def is_valid_credential(provider, model_type, model_name, model_credential: Dict @param raise_exception: 是否抛出错误 @return: True|False """ - return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, raise_exception) + return get_provider(provider).is_valid_credential(model_type, model_name, model_credential, model_params, raise_exception) diff --git a/apps/setting/models_provider/base_model_provider.py b/apps/setting/models_provider/base_model_provider.py index 39a759a65..35c3ef029 100644 --- a/apps/setting/models_provider/base_model_provider.py +++ b/apps/setting/models_provider/base_model_provider.py @@ -67,9 +67,13 @@ class IModelProvider(ABC): model_info = self.get_model_info_manage().get_model_info(model_type, model_name) return model_info.model_credential - def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): + def get_model_params(self, model_type, model_name): model_info = self.get_model_info_manage().get_model_info(model_type, model_name) - return model_info.model_credential.is_valid(model_type, model_name, model_credential, self, + return model_info.model_credential + + def is_valid_credential(self, model_type, model_name, model_credential: Dict[str, object], model_params: Dict[str, object], raise_exception=False): + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_credential.is_valid(model_type, model_name, model_credential, model_params, self, raise_exception=raise_exception) def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel: @@ -105,7 +109,7 @@ class MaxKBBaseModel(ABC): class BaseModelCredential(ABC): @abstractmethod - def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True): + def is_valid(self, model_type: str, model_name, model: Dict[str, object], model_params, provider, raise_exception=True): pass @abstractmethod diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py index 7884e5142..09ba7a752 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -17,7 +17,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py index e77d4dfdf..d3205d8ab 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py @@ -37,7 +37,7 @@ class QwenModelParams(BaseForm): class QwenVLModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -49,7 +49,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py index 1c822adef..db5f156cc 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py @@ -28,7 +28,7 @@ class BaiLianLLMModelParams(BaseForm): class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -41,7 +41,7 @@ class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py index d8d2f3cab..cd72274d3 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -19,7 +19,7 @@ from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): if not model_type == 'RERANKER': raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py index 5c9290b15..286650f1c 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py @@ -11,7 +11,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField("API Key", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py index 395d94db9..8fc39db7f 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py @@ -61,7 +61,7 @@ class QwenModelParams(BaseForm): class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -73,7 +73,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py index 640ba7a01..09a2bbe4b 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py @@ -45,7 +45,7 @@ class AliyunBaiLianTTSModelGeneralParams(BaseForm): class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField("API Key", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -58,7 +58,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py index 520960d7a..766c505cf 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py @@ -26,7 +26,7 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): with open(credentials_path, 'w') as file: file.write(content) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(mt.get('value') == model_type for mt in model_type_list): diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index df18fc6ac..9d1cd4210 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -44,7 +44,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): with open(credentials_path, 'w') as file: file.write(content) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(mt.get('value') == model_type for mt in model_type_list): @@ -62,7 +62,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): self._update_aws_credentials('aws-profile', model_credential['access_key_id'], model_credential['secret_access_key']) model_credential['credentials_profile_name'] = 'aws-profile' - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except AppApiException: raise diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/azure_model_provider/credential/embedding.py index baccfff52..12f26b53f 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/embedding.py @@ -19,7 +19,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/image.py b/apps/setting/models_provider/impl/azure_model_provider/credential/image.py index 0c2eeb77b..3c93d557d 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/image.py @@ -33,7 +33,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) api_key = forms.PasswordInputField("API Key (api_key)", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -46,7 +46,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py index 09e51dca6..a902551c8 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class AzureLLMModelParams(BaseForm): class AzureLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/stt.py b/apps/setting/models_provider/impl/azure_model_provider/credential/stt.py index 53fa46a44..e337f848b 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/stt.py @@ -12,7 +12,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) api_key = forms.PasswordInputField("API Key (api_key)", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/tti.py b/apps/setting/models_provider/impl/azure_model_provider/credential/tti.py index 2079ba741..227700351 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/tti.py @@ -51,7 +51,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) api_key = forms.PasswordInputField("API Key (api_key)", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -64,7 +64,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/tts.py b/apps/setting/models_provider/impl/azure_model_provider/credential/tts.py index c1fa4ec6d..9aed903d9 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/tts.py @@ -28,7 +28,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True) api_key = forms.PasswordInputField("API Key (api_key)", required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -41,7 +41,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py index 9739b71ac..3e861ec36 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class DeepSeekLLMModelParams(BaseForm): class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,7 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/embedding.py index 5fc0a9c88..6649f5923 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/embedding.py @@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class GeminiEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/image.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/image.py index 33cc60bbd..2c3fe7366 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/image.py @@ -31,7 +31,7 @@ class GeminiImageModelParams(BaseForm): class GeminiImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -44,7 +44,7 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py index 4cacbe12f..c2f2cb780 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class GeminiLLMModelParams(BaseForm): class GeminiLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,7 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.invoke([HumanMessage(content='你好')]) print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/stt.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/stt.py index 90e0164f0..cfa3aa79c 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/stt.py @@ -10,7 +10,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class GeminiSTTModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py index a6d06a894..2feed9a03 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class KimiLLMModelParams(BaseForm): class KimiLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,7 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py index a631196eb..ec899c25d 100644 --- a/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py @@ -17,7 +17,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo class LocalEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): if not model_type == 'EMBEDDING': raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') diff --git a/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py index 0048fcedb..ee89a5c96 100644 --- a/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py +++ b/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py @@ -19,7 +19,7 @@ from setting.models_provider.impl.local_model_provider.model.reranker import Loc class LocalRerankerCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): if not model_type == 'RERANKER': raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py index e0eeabe59..235aa9633 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py @@ -16,7 +16,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/image.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/image.py index 35b366266..c285feefa 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/image.py @@ -32,7 +32,7 @@ class OllamaImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 33f6d8c26..ab0749fe3 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -32,7 +32,7 @@ class OllamaLLMModelParams(BaseForm): class OllamaLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py index d49d22e22..15c1b2add 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/embedding.py @@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py index 83c1e70d1..9196bbede 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/image.py @@ -32,7 +32,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -45,7 +45,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py index 755f9558a..8e606b754 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class OpenAILLMModelParams(BaseForm): class OpenAILLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,8 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py index 59506311c..e237a77bd 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/stt.py @@ -11,7 +11,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py index 668ebca8a..2e90ed7e2 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tti.py @@ -50,7 +50,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -63,7 +63,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py b/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py index 96d00131a..4f607a9e7 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/tts.py @@ -27,7 +27,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -40,7 +40,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py index e77d4dfdf..d3205d8ab 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/image.py @@ -37,7 +37,7 @@ class QwenModelParams(BaseForm): class QwenVLModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -49,7 +49,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py index b9bb45ea9..68745a175 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class QwenModelParams(BaseForm): class OpenAILLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -47,7 +47,7 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py index 395d94db9..8fc39db7f 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/tti.py @@ -61,7 +61,7 @@ class QwenModelParams(BaseForm): class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -73,7 +73,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py index a0b00649d..9f4d7e58b 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/embedding.py @@ -8,7 +8,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class TencentEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True) -> bool: model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py index ae534d41b..9f0634833 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/image.py @@ -37,7 +37,7 @@ class QwenModelParams(BaseForm): class TencentVisionModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -49,7 +49,7 @@ class TencentVisionModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py index 20b1bf824..8481a2c48 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/llm.py @@ -36,12 +36,12 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential): return False return True - def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False): + def is_valid(self, model_type, model_name, model_credential, provider, model_params, raise_exception=False): if not (self._validate_model_type(model_type, provider, raise_exception) and self._validate_credential_fields(model_credential, raise_exception)): return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if raise_exception: diff --git a/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py b/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py index 1b6183e8a..c0c8d583d 100644 --- a/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/tencent_model_provider/credential/tti.py @@ -85,12 +85,12 @@ class TencentTTIModelCredential(BaseForm, BaseModelCredential): return False return True - def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False): + def is_valid(self, model_type, model_name, model_credential, model_params, provider, raise_exception=False): if not (self._validate_model_type(model_type, provider, raise_exception) and self._validate_credential_fields(model_credential, raise_exception)): return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if raise_exception: diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py index 6c14a45fa..5768d5f00 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py @@ -28,7 +28,7 @@ class VLLMModelParams(BaseForm): class VLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -40,7 +40,7 @@ class VLLMModelCredential(BaseForm, BaseModelCredential): exist = provider.get_model_info_by_name(model_list, model_name) if len(exist) == 0: raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) try: res = model.invoke([HumanMessage(content='你好')]) print(res) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py index d49d22e22..15c1b2add 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py @@ -15,7 +15,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py index ff31b5ef0..08a85a970 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py @@ -30,7 +30,7 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) api_base = forms.TextInputField('API 域名', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -43,7 +43,7 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py index 48c434b4b..0fb352422 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class VolcanicEngineLLMModelParams(BaseForm): class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -48,7 +48,7 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.invoke([HumanMessage(content='你好')]) print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py index d7607ded0..37980306e 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -14,7 +14,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): volcanic_token = forms.PasswordInputField('Access Token', required=True) volcanic_cluster = forms.TextInputField('Cluster ID', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py index 9e6a7967d..98b78a1e9 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py @@ -31,7 +31,7 @@ class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential): access_key = forms.PasswordInputField('Access Key ID', required=True) secret_key = forms.PasswordInputField('Secret Access Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -44,7 +44,7 @@ class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py index b565b162b..7f157b999 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py @@ -42,7 +42,7 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): volcanic_token = forms.PasswordInputField('Access Token', required=True) volcanic_cluster = forms.TextInputField('Cluster ID', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -55,7 +55,7 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/embedding.py index 25af4d5ab..9b6c780ba 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/embedding.py @@ -16,7 +16,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class QianfanEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index 342cb2e08..9a24becb1 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -34,12 +34,12 @@ class WenxinLLMModelParams(BaseForm): class WenxinLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model_info = [model.lower() for model in model.client.models()] if not model_info.__contains__(model_name.lower()): raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持') diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py index 63214bdc5..085a33065 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/embedding.py @@ -16,7 +16,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class XFEmbeddingCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py index 449f32129..e0486adb2 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/image.py @@ -18,7 +18,7 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential): spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -31,7 +31,7 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) cwd = os.path.dirname(os.path.abspath(__file__)) with open(f'{cwd}/img_1.png', 'rb') as f: message_list = [ImageMessage(str(base64.b64encode(f.read()), 'utf-8')), HumanMessage('请概述这张图片')] diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py index 8ec12e308..ae2cbea19 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py @@ -52,7 +52,7 @@ class XunFeiLLMModelProParams(BaseForm): class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -65,7 +65,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py index bf051c18a..f93922800 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/stt.py @@ -14,7 +14,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py index ec9478aae..99f2c6cc5 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/tts.py @@ -36,7 +36,7 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): spark_api_key = forms.PasswordInputField("API Key", required=True) spark_api_secret = forms.PasswordInputField('API Secret', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -49,7 +49,7 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py index 7cddb4f09..75e58ff53 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py @@ -9,7 +9,7 @@ from setting.models_provider.impl.local_model_provider.model.embedding import Lo class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object],model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py index e2cbbb194..addddae53 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py @@ -32,7 +32,7 @@ class XinferenceImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -45,7 +45,7 @@ class XinferenceImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py index dc01c7906..48e540aac 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -28,7 +28,7 @@ class XinferenceLLMModelParams(BaseForm): class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -40,7 +40,7 @@ class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): exist = provider.get_model_info_by_name(model_list, model_name) if len(exist) == 0: raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) return True diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py index 87f27971e..856e28fd7 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py @@ -17,7 +17,7 @@ from setting.models_provider.base_model_provider import BaseModelCredential, Val class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=True): if not model_type == 'RERANKER': raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py index 7d19feaef..3f47be6fc 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py @@ -11,7 +11,7 @@ class XInferenceSTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py index eba50d022..a4d487e7c 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py @@ -50,7 +50,7 @@ class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -63,7 +63,7 @@ class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py index 0bf3daadd..883519b7e 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py @@ -29,7 +29,7 @@ class XInferenceTTSModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API 域名', required=True) api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -42,7 +42,7 @@ class XInferenceTTSModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py index 54bd19e14..c9092371e 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/image.py @@ -29,7 +29,7 @@ class ZhiPuImageModelParams(BaseForm): class ZhiPuImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -42,7 +42,7 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]) for chunk in res: print(chunk) diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py index 48c1194ee..55d1fad67 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -35,7 +35,7 @@ class ZhiPuLLMModelParams(BaseForm): class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -47,7 +47,7 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content='你好')]) except Exception as e: if isinstance(e, AppApiException): diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py index f951efd9e..0b58fa65a 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/tti.py @@ -29,7 +29,7 @@ class ZhiPuTTIModelParams(BaseForm): class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, raise_exception=False): model_type_list = provider.get_model_type_list() if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): @@ -42,7 +42,7 @@ class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential): else: return False try: - model = provider.get_model(model_type, model_name, model_credential) + model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() print(res) except Exception as e: diff --git a/apps/setting/serializers/provider_serializers.py b/apps/setting/serializers/provider_serializers.py index b8f314127..b966e331f 100644 --- a/apps/setting/serializers/provider_serializers.py +++ b/apps/setting/serializers/provider_serializers.py @@ -189,9 +189,11 @@ class ModelSerializer(serializers.Serializer): if QuerySet(Model).filter(user_id=self.data.get('user_id'), name=self.data.get('name')).exists(): raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') + default_params = {item['field']: item['default_value'] for item in self.data.get('model_params_form')} ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'), self.data.get('model_name'), self.data.get('credential'), + default_params, raise_exception=True ) @@ -354,10 +356,12 @@ class ModelSerializer(serializers.Serializer): model=model) try: model.status = Status.SUCCESS + default_params = {item['field']: item['default_value'] for item in model.model_params_form} # 校验模型认证数据 provider_handler.is_valid_credential(model.model_type, instance.get("model_name"), credential, + default_params, raise_exception=True) except AppApiException as e: