MaxKB/apps/models_provider/impl/ollama_model_provider/model/llm.py
2025-04-17 18:01:33 +08:00

49 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# coding=utf-8
"""
@project: maxkb
@Author
@file llm.py
@date2024/3/6 11:48
@desc:
"""
from typing import List, Dict
from urllib.parse import urlparse, ParseResult
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_ollama.chat_models import ChatOllama
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel
def get_base_url(url: str):
parse = urlparse(url)
result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
query='',
fragment='').geturl()
return result_url[:-1] if result_url.endswith("/") else result_url
class OllamaChatModel(MaxKBBaseModel, ChatOllama):
@staticmethod
def is_cache_model():
return False
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
api_base = model_credential.get('api_base', '')
base_url = get_base_url(api_base)
optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
return OllamaChatModel(model=model_name, base_url=base_url,
stream=True, **optional_params)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))