diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index 89cb3ed35..fb1e77cdd 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -18,9 +18,10 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel def get_base_url(url: str): parse = urlparse(url) - return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', - query='', - fragment='').geturl() + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): @@ -28,7 +29,8 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): api_base = model_credential.get('api_base', '') base_url = get_base_url(api_base) - return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'), + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + return OllamaChatModel(model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key')) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index eb01d38d7..6f0fc942c 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -113,9 +113,10 @@ model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_ def get_base_url(url: str): parse = urlparse(url) - return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', - query='', - fragment='').geturl() + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url def convert_to_down_model_chunk(row_str: str, chunk_index: int):