From 38b58d3e5e19947f76a118c9821ac8b0fa5d49e2 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Tue, 6 Aug 2024 11:34:05 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dollama=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B=E4=BA=8C=E7=BA=A7path=E6=97=A0=E6=B3=95=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E9=97=AE=E9=A2=98=20#805=20(#930)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 62b12c31247cdb6d3fc6d07267a8cfbe1c676358) --- .../impl/ollama_model_provider/model/llm.py | 10 ++++++---- .../ollama_model_provider/ollama_model_provider.py | 7 ++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py index 89cb3ed35..fb1e77cdd 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/model/llm.py @@ -18,9 +18,10 @@ from setting.models_provider.base_model_provider import MaxKBBaseModel def get_base_url(url: str): parse = urlparse(url) - return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', - query='', - fragment='').geturl() + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): @@ -28,7 +29,8 @@ class OllamaChatModel(MaxKBBaseModel, ChatOpenAI): def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): api_base = model_credential.get('api_base', '') base_url = get_base_url(api_base) - return OllamaChatModel(model=model_name, openai_api_base=(base_url + '/v1'), + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + return OllamaChatModel(model=model_name, openai_api_base=base_url, openai_api_key=model_credential.get('api_key')) def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py index eb01d38d7..6f0fc942c 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py @@ -113,9 +113,10 @@ model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_ def get_base_url(url: str): parse = urlparse(url) - return ParseResult(scheme=parse.scheme, netloc=parse.netloc, path='', params='', - query='', - fragment='').geturl() + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url def convert_to_down_model_chunk(row_str: str, chunk_index: int):