diff --git a/apps/chat/serializers/chat.py b/apps/chat/serializers/chat.py index 1a703a41c..47c658ee9 100644 --- a/apps/chat/serializers/chat.py +++ b/apps/chat/serializers/chat.py @@ -158,9 +158,12 @@ class PromptGenerateSerializer(serializers.Serializer): q = prompt.replace("{userInput}", message) messages[-1]['content'] = q - model_exist = QuerySet(Model).filter(workspace_id=workspace_id, id=model_id).exists() + model_exist = QuerySet(Model).filter(workspace_id=workspace_id, + id=model_id, + model_type = "LLM" + ).exists() if not model_exist: - raise Exception(_("model does not exists")) + raise Exception(_("Model does not exists or is not an LLM model")) def process(): model = get_model_instance_by_model_workspace_id(model_id=model_id, workspace_id=workspace_id) diff --git a/apps/models_provider/impl/vllm_model_provider/model/reranker.py b/apps/models_provider/impl/vllm_model_provider/model/reranker.py index e11b3aa9b..1736578e7 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/reranker.py +++ b/apps/models_provider/impl/vllm_model_provider/model/reranker.py @@ -28,10 +28,11 @@ class VllmBgeReranker(MaxKBBaseModel, BaseDocumentCompressor): @staticmethod def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + r_url = model_credential.get('api_url')[:-3] if model_credential.get('api_url').endswith('/v1') else model_credential.get('api_url') return VllmBgeReranker( model=model_name, api_key=model_credential.get('api_key'), - api_url=model_credential.get('api_url'), + api_url=r_url, params=model_kwargs, **model_kwargs ) diff --git a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py index 62075a611..f57c046e1 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py +++ b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py @@ -44,7 +44,9 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText): self.speech_to_text(audio_file) def speech_to_text(self, audio_file): - base_url = f"{self.api_url}/v1" + + base_url = self.api_url if self.api_url.endswith('v1') else f"{self.api_url}/v1" + try: client = OpenAI( api_key=self.api_key,