feat: integrate models on the Ollama platform

shaohuzhang1 2024-03-06 13:43:45 +08:00
parent 71c29e81c9
commit ba3e4e7556
18 changed files with 203 additions and 22 deletions

View File

@ -142,7 +142,7 @@ class BaseChatStep(IChatStep):
chat_result = AIMessage(
content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
else:
chat_result = chat_model(message_list)
chat_result = chat_model.invoke(message_list)
chat_record_id = uuid.uuid1()
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
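
Note: with the langchain 0.1.x upgrade in this commit, calling a chat model directly (chat_model(message_list)) is deprecated in favor of .invoke(). A minimal sketch of the new calling convention, assuming a local Ollama instance exposing its OpenAI-compatible API (the URL, key, and model name below are placeholders):

from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage

# Placeholder endpoint and key for a local Ollama server; adjust to your deployment.
chat_model = ChatOpenAI(model="llama2",
                        openai_api_base="http://localhost:11434/v1",
                        openai_api_key="ollama")

message_list = [HumanMessage(content="Hello")]
chat_result = chat_model.invoke(message_list)  # replaces the deprecated chat_model(message_list)
print(chat_result.content)                     # chat_result is an AIMessage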

View File

@ -10,7 +10,7 @@ import json
from typing import List
from uuid import UUID
from django.core.cache import cache
from django.core.cache import caches
from django.db.models import QuerySet
from langchain.chat_models.base import BaseChatModel
from rest_framework import serializers
@ -30,7 +30,7 @@ from dataset.models import Paragraph, Document
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
chat_cache = cache
chat_cache = caches['model_cache']
class ChatInfo:

View File

@ -12,10 +12,10 @@ import os
import re
import uuid
from functools import reduce
from typing import Dict, List
from typing import Dict
from django.core import validators
from django.core.cache import cache
from django.core.cache import cache, caches
from django.db import transaction, models
from django.db.models import QuerySet, Q
from rest_framework import serializers
@ -32,14 +32,13 @@ from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from common.util.lock import try_lock, un_lock
from common.util.rsa_util import decrypt
from common.util.split_model import flat_map
from dataset.models import Document, Problem, Paragraph
from dataset.serializers.paragraph_serializers import ParagraphSerializers
from setting.models import Model
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from smartdoc.conf import PROJECT_DIR
chat_cache = cache
chat_cache = caches['model_cache']
class ChatSerializers(serializers.Serializer):

apps/common/cache/mem_cache.py (new file, 31 lines)
View File

@ -0,0 +1,31 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: mem_cache.py
@date: 2024/3/6 11:20
@desc:
"""
from django.core.cache.backends.base import DEFAULT_TIMEOUT
from django.core.cache.backends.locmem import LocMemCache
class MemCache(LocMemCache):
    """In-process cache that, unlike LocMemCache, stores values without pickling,
    so live objects such as instantiated chat models can be cached and reused."""

    def __init__(self, name, params):
        super().__init__(name, params)

    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
        key = self.make_and_validate_key(key, version=version)
        # Keep the raw value instead of pickling it.
        pickled = value
        with self._lock:
            self._set(key, pickled, timeout)

    def get(self, key, default=None, version=None):
        key = self.make_and_validate_key(key, version=version)
        with self._lock:
            if self._has_expired(key):
                self._delete(key)
                return default
            pickled = self._cache[key]
            # LRU bookkeeping: mark this key as most recently used.
            self._cache.move_to_end(key, last=False)
        return pickled
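
Note: unlike the stock LocMemCache, this backend stores and returns values without pickling, so live Python objects (for example an instantiated chat model) can be cached and reused within the process. A standalone sketch of that behavior, assuming Django is installed (the cache name and key are illustrative):

from common.cache.mem_cache import MemCache

cache = MemCache('model_cache', {})
obj = object()                           # stand-in for a non-picklable chat model instance
cache.set('model:demo', obj, timeout=300)
assert cache.get('model:demo') is obj    # the very same object comes back, no pickle round trip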

View File

@ -8,7 +8,7 @@
"""
import types
from smartdoc.const import CONFIG
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
class EmbeddingModel:

View File

@ -10,7 +10,7 @@ import threading
from abc import ABC, abstractmethod
from typing import List, Dict
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from common.config.embedding_config import EmbeddingModel
from common.util.common import sub_array

View File

@ -12,7 +12,7 @@ import uuid
from typing import Dict, List
from django.db.models import QuerySet
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from common.db.search import native_search, generate_sql_by_query_dict
from common.db.sql_execute import select_one, select_list

View File

@ -9,11 +9,9 @@
from abc import ABC, abstractmethod
from enum import Enum
from functools import reduce
from typing import Dict, List
from typing import Dict
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage
from langchain.schema.language_model import LanguageModelInput
class IModelProvider(ABC):

View File

@ -9,9 +9,11 @@
from enum import Enum
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
class ModelProvideConstants(Enum):
model_azure_provider = AzureModelProvider()
model_wenxin_provider = WenxinModelProvider()
model_ollama_provider = OllamaModelProvider()
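
Note: the enum maps a provider key to its provider instance, so the new Ollama entry can be resolved by name. A small sketch (the key is the one registered above):

from setting.models_provider.constants.model_provider_constants import ModelProvideConstants

provider = ModelProvideConstants['model_ollama_provider'].value  # Enum lookup by member name
print(type(provider).__name__)                                   # OllamaModelProvider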

View File

@ -9,7 +9,7 @@
import os
from typing import Dict
from langchain.chat_models import AzureChatOpenAI
from langchain_community.chat_models import AzureChatOpenAI
from langchain.schema import HumanMessage
from common import froms
@ -40,7 +40,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
return False
try:
model = AzureModelProvider().get_model(model_type, model_name, model_credential)
model([HumanMessage(content='valid')])
model.invoke([HumanMessage(content='valid')])
except Exception as e:
if isinstance(e, AppApiException):
raise e

View File

@ -0,0 +1,8 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: __init__.py
@date: 2024/3/5 17:20
@desc:
"""

File diff suppressed because one or more lines are too long

(New image file, 10 KiB; presumably the Ollama provider icon referenced below as ollama_icon_svg.)

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: ollama_chat_model.py
@date: 2024/3/6 11:48
@desc:
"""
from typing import List
from langchain_community.chat_models import ChatOpenAI
from langchain_core.messages import BaseMessage, get_buffer_string
from transformers import GPT2TokenizerFast
# A GPT-2 tokenizer (cached locally) is used to approximate token counts for Ollama-served models.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer",
                                              resume_download=False, force_download=False)


class OllamaChatModel(ChatOpenAI):

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        # Approximate prompt token usage by encoding each message's rendered text.
        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])

    def get_num_tokens(self, text: str) -> int:
        return len(tokenizer.encode(text))
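
Note: requests served through Ollama do not go through tiktoken, so token usage is approximated here with a GPT-2 tokenizer. A standalone sketch of that estimate (the cache directory is omitted for brevity):

from transformers import GPT2TokenizerFast
from langchain_core.messages import HumanMessage, get_buffer_string

tok = GPT2TokenizerFast.from_pretrained('gpt2')
messages = [HumanMessage(content="Hello, world")]
request_tokens = sum(len(tok.encode(get_buffer_string([m]))) for m in messages)
print(request_tokens)  # rough estimate only; not exact for Llama-family tokenizers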

View File

@ -0,0 +1,113 @@
# coding=utf-8
"""
@project: maxkb
@Author:
@file: ollama_model_provider.py
@date: 2024/3/5 17:23
@desc:
"""
import os
from typing import Dict
from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage
from common import froms
from common.exception.app_exception import AppApiException
from common.froms import BaseForm
from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
BaseModelCredential
from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
from smartdoc.conf import PROJECT_DIR
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
        model_type_list = OllamaModelProvider().get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(500, f'{model_type} model type is not supported')
        if model_name not in model_dict:
            raise AppApiException(500, f'{model_name} model name is not supported')
        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(500, f'{key} is a required field')
                else:
                    return False
        try:
            # Validate the credential with a minimal round trip to the model.
            OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke(
                [HumanMessage(content='valid')])
        except Exception as e:
            if raise_exception:
                raise AppApiException(500, "Validation failed, please check whether api_key is correct")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        for key in ['api_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} is a required field')
        self.api_key = model_info.get('api_key')
        return self

    api_base = froms.TextInputField('API Domain', required=True)
    api_key = froms.PasswordInputField('API Key', required=True)


ollama_llm_model_credential = OllamaLLMModelCredential()
model_dict = {
    'llama2': ModelInfo(
        'llama2',
        'Llama 2 is a collection of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 7B pretrained model. Links to the other models can be found in the index at the bottom.',
        ModelTypeConst.LLM, ollama_llm_model_credential),
    'llama2:13b': ModelInfo(
        'llama2:13b',
        'Llama 2 is a collection of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 13B pretrained model. Links to the other models can be found in the index at the bottom.',
        ModelTypeConst.LLM, ollama_llm_model_credential),
    'llama2:70b': ModelInfo(
        'llama2:70b',
        'Llama 2 is a collection of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 70B pretrained model. Links to the other models can be found in the index at the bottom.',
        ModelTypeConst.LLM, ollama_llm_model_credential),
    'llama2-chinese:13b': ModelInfo(
        'llama2-chinese:13b',
        'Because Llama 2 itself is weakly aligned for Chinese, meta-llama/Llama-2-13b-chat-hf was LoRA fine-tuned on a Chinese instruction dataset to give it strong Chinese conversational ability.',
        ModelTypeConst.LLM, ollama_llm_model_credential),
    'llama2-chinese:13b-maxkb': ModelInfo(
        'llama2-chinese:13b-maxkb',
        'Because Llama 2 itself is weakly aligned for Chinese, meta-llama/Llama-2-13b-chat-hf was LoRA fine-tuned on a Chinese instruction dataset to give it strong Chinese conversational ability. Dedicated to fi2cloud.',
        ModelTypeConst.LLM, ollama_llm_model_credential),
}
class OllamaModelProvider(IModelProvider):

    def get_model_provide_info(self):
        return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
                         'ollama_icon_svg')))

    def get_model_type_list(self):
        return [{'key': "Large Language Model", 'value': "LLM"}]

    def get_model_list(self, model_type):
        if model_type is None:
            raise AppApiException(500, 'Model type cannot be empty')
        return [model_dict.get(key).to_dict() for key in
                list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]

    def get_model_credential(self, model_type, model_name):
        if model_name in model_dict:
            return model_dict.get(model_name).model_credential
        raise AppApiException(500, f'Unsupported model: {model_name}')

    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
        # Ollama exposes an OpenAI-compatible API, so the OpenAI chat client is reused.
        return OllamaChatModel(model=model_name, openai_api_base=model_credential.get('api_base'),
                               openai_api_key=model_credential.get('api_key'))

    def get_dialogue_number(self):
        return 2
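
Note: a hedged usage sketch of the new provider. get_model returns an OllamaChatModel pointed at Ollama's OpenAI-compatible endpoint; the base URL and key below are placeholders for a local Ollama instance:

from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider

provider = OllamaModelProvider()
model = provider.get_model('LLM', 'llama2',
                           {'api_base': 'http://localhost:11434/v1', 'api_key': 'ollama'})
print(model.invoke('Hello').content)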

View File

@ -9,7 +9,7 @@
from typing import Optional, List, Any, Iterator, cast
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models import QianfanChatEndpoint
from langchain_community.chat_models import QianfanChatEndpoint
from langchain.chat_models.base import BaseChatModel
from langchain.load import dumpd
from langchain.schema import LLMResult

View File

@ -9,8 +9,7 @@
import os
from typing import Dict
from langchain.chat_models import QianfanChatEndpoint
from langchain.chat_models.baidu_qianfan_endpoint import convert_message_to_dict
from langchain_community.chat_models import QianfanChatEndpoint
from langchain.schema import HumanMessage
from common import froms
@ -39,7 +38,8 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
else:
return False
try:
WenxinModelProvider().get_model(model_type, model_name, model_credential)([HumanMessage(content='valid')])
WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
[HumanMessage(content='valid')])
except Exception as e:
if raise_exception:
raise AppApiException(500, "校验失败,请检查 api_key secret_key 是否正确")

View File

@ -89,6 +89,9 @@ CACHES = {
"default": {
'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
},
'model_cache': {
'BACKEND': 'common.cache.mem_cache.MemCache'
},
# Stores user information
'user_cache': {
'BACKEND': 'common.cache.file_cache.FileCache',
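
Note: this registers the new MemCache backend under the 'model_cache' alias used by the serializer changes above. Assuming these settings are loaded, application code reaches it through Django's cache registry:

from django.core.cache import caches

chat_cache = caches['model_cache']  # resolves to common.cache.mem_cache.MemCache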

View File

@ -12,7 +12,7 @@ djangorestframework = "3.14.0"
drf-yasg = "1.21.7"
django-filter = "23.2"
elasticsearch = "8.9.0"
langchain = "^0.0.321"
langchain = "^0.1.11"
psycopg2-binary = "2.9.7"
jieba = "^0.42.1"
diskcache = "^5.6.3"
@ -21,13 +21,13 @@ filetype = "^1.2.0"
chardet = "^5.2.0"
sentence-transformers = "^2.2.2"
blinker = "^1.6.3"
openai = "^0.28.1"
openai = "^^1.13.3"
tiktoken = "^0.5.1"
qianfan = "^0.1.1"
pycryptodome = "^3.19.0"
beautifulsoup4 = "^4.12.2"
html2text = "^2024.2.26"
langchain-openai = "^0.0.8"
[build-system]
requires = ["poetry-core"]
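
Note: the version bumps above (langchain 0.1.x, openai 1.x, plus the new langchain-openai package) are what drive the import-path changes throughout this commit: integration classes move out of langchain into langchain_community. A short summary of the new import locations used here:

from langchain_community.chat_models import ChatOpenAI, AzureChatOpenAI, QianfanChatEndpoint
from langchain_community.embeddings import HuggingFaceEmbeddings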