Mirror of https://github.com/1Panel-dev/MaxKB.git, synced 2025-12-26 10:12:51 +00:00

feat: 对接ollama平台模型 (integrate Ollama platform models)

This commit is contained in:
parent 71c29e81c9
commit ba3e4e7556

@@ -142,7 +142,7 @@ class BaseChatStep(IChatStep):
             chat_result = AIMessage(
                 content="\n\n".join([paragraph.title + "\n" + paragraph.content for paragraph in paragraph_list]))
         else:
-            chat_result = chat_model(message_list)
+            chat_result = chat_model.invoke(message_list)
         chat_record_id = uuid.uuid1()
         request_token = chat_model.get_num_tokens_from_messages(message_list)
         response_token = chat_model.get_num_tokens(chat_result.content)

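The one-line change in this hunk tracks LangChain 0.1's Runnable interface: calling a chat model directly via `__call__` is deprecated in favor of `.invoke()`. A minimal sketch of the old and new call styles; the model class and message below are illustrative, not taken from this codebase:

```python
# Sketch: LangChain 0.1 call-style migration. Any BaseChatModel subclass
# behaves the same way; credentials are omitted for brevity.
from langchain_community.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat_model = ChatOpenAI(model="gpt-3.5-turbo")  # illustrative model
message_list = [HumanMessage(content="Hello")]

# Old (pre-0.1, now deprecated): chat_result = chat_model(message_list)
chat_result = chat_model.invoke(message_list)   # new Runnable-style call
print(chat_result.content)                      # AIMessage payload
```
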
@@ -10,7 +10,7 @@ import json
 from typing import List
 from uuid import UUID
 
-from django.core.cache import cache
+from django.core.cache import caches
 from django.db.models import QuerySet
 from langchain.chat_models.base import BaseChatModel
 from rest_framework import serializers

@@ -30,7 +30,7 @@ from dataset.models import Paragraph, Document
 from setting.models import Model
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
 
-chat_cache = cache
+chat_cache = caches['model_cache']
 
 
 class ChatInfo:

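`cache` is Django's alias for the "default" entry in `CACHES`, so switching to `caches['model_cache']` pins chat state to the dedicated backend this commit registers in the settings hunk near the end of the diff. A short sketch of the difference, assuming that settings entry exists:

```python
# Sketch: Django named-cache lookup, assuming CACHES defines 'model_cache'
# (it is added to the smartdoc settings later in this diff).
from django.core.cache import cache, caches

default_cache = cache               # always the CACHES["default"] backend
chat_cache = caches['model_cache']  # the MemCache-backed entry used for chat sessions

chat_cache.set('chat:42', {'history': []}, timeout=None)  # illustrative key and value
print(chat_cache.get('chat:42'))
```
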
@@ -12,10 +12,10 @@ import os
 import re
 import uuid
 from functools import reduce
-from typing import Dict, List
+from typing import Dict
 
 from django.core import validators
-from django.core.cache import cache
+from django.core.cache import cache, caches
 from django.db import transaction, models
 from django.db.models import QuerySet, Q
 from rest_framework import serializers

@@ -32,14 +32,13 @@ from common.util.field_message import ErrMessage
 from common.util.file_util import get_file_content
 from common.util.lock import try_lock, un_lock
 from common.util.rsa_util import decrypt
 from common.util.split_model import flat_map
 from dataset.models import Document, Problem, Paragraph
 from dataset.serializers.paragraph_serializers import ParagraphSerializers
-from setting.models import Model
 from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
 from smartdoc.conf import PROJECT_DIR
 
-chat_cache = cache
+chat_cache = caches['model_cache']
 
 
 class ChatSerializers(serializers.Serializer):

@@ -0,0 +1,31 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: mem_cache.py
+@date:2024/3/6 11:20
+@desc:
+"""
+from django.core.cache.backends.base import DEFAULT_TIMEOUT
+from django.core.cache.backends.locmem import LocMemCache
+
+
+class MemCache(LocMemCache):
+    def __init__(self, name, params):
+        super().__init__(name, params)
+
+    def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        pickled = value
+        with self._lock:
+            self._set(key, pickled, timeout)
+
+    def get(self, key, default=None, version=None):
+        key = self.make_and_validate_key(key, version=version)
+        with self._lock:
+            if self._has_expired(key):
+                self._delete(key)
+                return default
+            pickled = self._cache[key]
+            self._cache.move_to_end(key, last=False)
+            return pickled

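Stock `LocMemCache` pickles values in `set()` and unpickles them in `get()`; this subclass keeps the same locking, expiry, and LRU bookkeeping (`move_to_end(..., last=False)` marks a key as recently used) but stores the raw object reference, so values that cannot be pickled — such as instantiated chat-model objects — can live in the process-local cache. A hedged sketch of the behavior difference, with an illustrative value type:

```python
# Sketch: MemCache keeps live object references where LocMemCache would
# pickle them. LiveModelHandle is an illustrative stand-in for a loaded chat
# model; 'model_cache' must map to MemCache in CACHES for this to work.
import threading

from django.core.cache import caches


class LiveModelHandle:
    """Illustrative non-picklable value (a lock cannot be pickled)."""
    def __init__(self):
        self.lock = threading.Lock()


chat_cache = caches['model_cache']
chat_cache.set('chat:abc', LiveModelHandle(), timeout=30 * 60)
handle = chat_cache.get('chat:abc')  # same object back, no pickle round-trip
assert isinstance(handle, LiveModelHandle)
```
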
@@ -8,7 +8,7 @@
 """
 import types
 from smartdoc.const import CONFIG
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 
 
 class EmbeddingModel:

@@ -10,7 +10,7 @@ import threading
 from abc import ABC, abstractmethod
 from typing import List, Dict
 
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 
 from common.config.embedding_config import EmbeddingModel
 from common.util.common import sub_array

@@ -12,7 +12,7 @@ import uuid
 from typing import Dict, List
 
 from django.db.models import QuerySet
-from langchain.embeddings import HuggingFaceEmbeddings
+from langchain_community.embeddings import HuggingFaceEmbeddings
 
 from common.db.search import native_search, generate_sql_by_query_dict
 from common.db.sql_execute import select_one, select_list

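The three `HuggingFaceEmbeddings` hunks above, like the chat-model import hunks elsewhere in this commit, are mechanical fallout of the langchain 0.1 split, which moved third-party integrations out of the monolithic `langchain` package into the separate `langchain-community` distribution. The commit simply rewrites each import in place; a compatibility shim like the following is an alternative pattern, shown here only for illustration:

```python
# Sketch: tolerating both pre- and post-split LangChain layouts. This shim is
# NOT what the commit does (it rewrites imports directly); it only illustrates
# the packaging change behind these hunks.
try:
    from langchain_community.embeddings import HuggingFaceEmbeddings  # langchain >= 0.1
except ImportError:
    from langchain.embeddings import HuggingFaceEmbeddings            # langchain < 0.1

embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")  # model name illustrative
print(len(embedder.embed_query("hello")))
```
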
@@ -9,11 +9,9 @@
 from abc import ABC, abstractmethod
 from enum import Enum
 from functools import reduce
-from typing import Dict, List
+from typing import Dict
 
 from langchain.chat_models.base import BaseChatModel
-from langchain.schema import BaseMessage
-from langchain.schema.language_model import LanguageModelInput
 
 
 class IModelProvider(ABC):

@@ -9,9 +9,11 @@
 from enum import Enum
 
 from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
+from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider
 from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider
 
 
 class ModelProvideConstants(Enum):
     model_azure_provider = AzureModelProvider()
     model_wenxin_provider = WenxinModelProvider()
+    model_ollama_provider = OllamaModelProvider()

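Registering `model_ollama_provider` on the enum is all it takes to expose the new provider: elsewhere the codebase resolves a provider implementation from the string key stored with each `Model` record. That lookup is not shown in this diff, so the following is an assumed usage sketch, not quoted code:

```python
# Hedged sketch: resolving a provider from its registry key. The key string is
# the enum member name; .value is the provider instance constructed above.
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants

provider_key = 'model_ollama_provider'                # e.g. persisted on a Model row
provider = ModelProvideConstants[provider_key].value  # -> OllamaModelProvider instance
print(provider.get_model_type_list())                 # [{'key': 'Large Language Model', 'value': 'LLM'}]
```
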
@@ -9,7 +9,7 @@
 import os
 from typing import Dict
 
-from langchain.chat_models import AzureChatOpenAI
+from langchain_community.chat_models import AzureChatOpenAI
 from langchain.schema import HumanMessage
 
 from common import froms

@@ -40,7 +40,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
                     return False
         try:
             model = AzureModelProvider().get_model(model_type, model_name, model_credential)
-            model([HumanMessage(content='valid')])
+            model.invoke([HumanMessage(content='valid')])
         except Exception as e:
             if isinstance(e, AppApiException):
                 raise e

@@ -0,0 +1,8 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: __init__.py.py
+@date:2024/3/5 17:20
+@desc:
+"""

New icon asset added (file diff suppressed because one or more lines are too long; image size 10 KiB).

@@ -0,0 +1,24 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: ollama_chat_model.py
+@date:2024/3/6 11:48
+@desc:
+"""
+from typing import List
+
+from langchain_community.chat_models import ChatOpenAI
+from langchain_core.messages import BaseMessage, get_buffer_string
+from transformers import GPT2TokenizerFast
+
+tokenizer = GPT2TokenizerFast.from_pretrained('gpt2', cache_dir="/opt/maxkb/model/tokenizer", resume_download=False,
+                                              force_download=False)
+
+
+class OllamaChatModel(ChatOpenAI):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+    def get_num_tokens(self, text: str) -> int:
+        return len(tokenizer.encode(text))

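`ChatOpenAI` normally counts tokens with OpenAI's tiktoken encodings, which have no equivalent for arbitrary Ollama-served models, so the subclass approximates both request and response token counts with a locally cached GPT-2 tokenizer. The same logic in isolation (tokenizer cache directory omitted):

```python
# Sketch of OllamaChatModel's approximate token accounting: a GPT-2 tokenizer
# stands in for model-specific tokenizers the OpenAI-compatible API can't expose.
from transformers import GPT2TokenizerFast
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

messages = [HumanMessage(content="Hello"), AIMessage(content="Hi there")]
request_tokens = sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
print(request_tokens)  # approximate; the served model's real tokenizer may differ
```
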
@@ -0,0 +1,113 @@
+# coding=utf-8
+"""
+@project: maxkb
+@Author:虎
+@file: ollama_model_provider.py
+@date:2024/3/5 17:23
+@desc:
+"""
+import os
+from typing import Dict
+
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import HumanMessage
+
+from common import froms
+from common.exception.app_exception import AppApiException
+from common.froms import BaseForm
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \
+    BaseModelCredential
+from setting.models_provider.impl.ollama_model_provider.model.ollama_chat_model import OllamaChatModel
+from smartdoc.conf import PROJECT_DIR
+
+
+class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+        model_type_list = OllamaModelProvider().get_model_type_list()
+        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+            raise AppApiException(500, f'{model_type} model type is not supported')
+
+        if model_name not in model_dict:
+            raise AppApiException(500, f'{model_name} model name is not supported')
+
+        for key in ['api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(500, f'{key} is a required field')
+                else:
+                    return False
+        try:
+            OllamaModelProvider().get_model(model_type, model_name, model_credential).invoke(
+                [HumanMessage(content='valid')])
+        except Exception as e:
+            if raise_exception:
+                raise AppApiException(500, 'Validation failed; check whether api_key and secret_key are correct')
+        return True
+
+    def encryption_dict(self, model_info: Dict[str, object]):
+        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}
+
+    def build_model(self, model_info: Dict[str, object]):
+        for key in ['api_key', 'model']:
+            if key not in model_info:
+                raise AppApiException(500, f'{key} is a required field')
+        self.api_key = model_info.get('api_key')
+        return self
+
+    api_base = froms.TextInputField('API domain', required=True)
+    api_key = froms.PasswordInputField('API Key', required=True)
+
+
+ollama_llm_model_credential = OllamaLLMModelCredential()
+
+model_dict = {
+    'llama2': ModelInfo(
+        'llama2',
+        'Llama 2 is a family of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 7B pretrained model; links to the other models are in the index at the bottom.',
+        ModelTypeConst.LLM, ollama_llm_model_credential),
+    'llama2:13b': ModelInfo(
+        'llama2:13b',
+        'Llama 2 is a family of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 13B pretrained model; links to the other models are in the index at the bottom.',
+        ModelTypeConst.LLM, ollama_llm_model_credential),
+    'llama2:70b': ModelInfo(
+        'llama2:70b',
+        'Llama 2 is a family of pretrained and fine-tuned generative text models ranging from 7 billion to 70 billion parameters. This is the repository for the 70B pretrained model; links to the other models are in the index at the bottom.',
+        ModelTypeConst.LLM, ollama_llm_model_credential),
+    'llama2-chinese:13b': ModelInfo(
+        'llama2-chinese:13b',
+        "Because Llama 2's own Chinese alignment is weak, we LoRA-fine-tune meta-llama/Llama-2-13b-chat-hf on a Chinese instruction set to give it strong Chinese conversational ability.",
+        ModelTypeConst.LLM, ollama_llm_model_credential),
+    'llama2-chinese:13b-maxkb': ModelInfo(
+        'llama2-chinese:13b-maxkb',
+        "Because Llama 2's own Chinese alignment is weak, we LoRA-fine-tune meta-llama/Llama-2-13b-chat-hf on a Chinese instruction set to give it strong Chinese conversational ability. Dedicated to fi2cloud.",
+        ModelTypeConst.LLM, ollama_llm_model_credential),
+}
+
+
+class OllamaModelProvider(IModelProvider):
+    def get_model_provide_info(self):
+        return ModelProvideInfo(provider='model_ollama_provider', name='Ollama', icon=get_file_content(
+            os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'ollama_model_provider', 'icon',
+                         'ollama_icon_svg')))
+
+    def get_model_type_list(self):
+        return [{'key': 'Large Language Model', 'value': 'LLM'}]
+
+    def get_model_list(self, model_type):
+        if model_type is None:
+            raise AppApiException(500, 'Model type must not be empty')
+        return [model_dict.get(key).to_dict() for key in
+                list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+    def get_model_credential(self, model_type, model_name):
+        if model_name in model_dict:
+            return model_dict.get(model_name).model_credential
+        raise AppApiException(500, f'Unsupported model: {model_name}')
+
+    def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseChatModel:
+        return OllamaChatModel(model=model_name, openai_api_base=model_credential.get('api_base'),
+                               openai_api_key=model_credential.get('api_key'))
+
+    def get_dialogue_number(self):
+        return 2

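Because `OllamaChatModel` subclasses `ChatOpenAI`, the provider reaches Ollama over its OpenAI-compatible HTTP API; the `api_base` form field therefore takes the server's base URL, conventionally something like `http://localhost:11434/v1` (an assumption — the diff does not prescribe a value, and Ollama ignores the API key). End-to-end, the provider is exercised the same way its credential check probes it:

```python
# Hedged sketch: building and probing an Ollama-backed model exactly as
# OllamaLLMModelCredential.is_valid does. The URL and key values are assumptions.
from langchain.schema import HumanMessage
from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider

provider = OllamaModelProvider()
credential = {'api_base': 'http://localhost:11434/v1', 'api_key': 'unused'}

model = provider.get_model('LLM', 'llama2', credential)
print(model.invoke([HumanMessage(content='valid')]).content)
```
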
@@ -9,7 +9,7 @@
 from typing import Optional, List, Any, Iterator, cast
 
 from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models import QianfanChatEndpoint
+from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.chat_models.base import BaseChatModel
 from langchain.load import dumpd
 from langchain.schema import LLMResult

@@ -9,8 +9,7 @@
 import os
 from typing import Dict
 
-from langchain.chat_models import QianfanChatEndpoint
-from langchain.chat_models.baidu_qianfan_endpoint import convert_message_to_dict
+from langchain_community.chat_models import QianfanChatEndpoint
 from langchain.schema import HumanMessage
 
 from common import froms

@@ -39,7 +38,8 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
                 else:
                     return False
         try:
-            WenxinModelProvider().get_model(model_type, model_name, model_credential)([HumanMessage(content='valid')])
+            WenxinModelProvider().get_model(model_type, model_name, model_credential).invoke(
+                [HumanMessage(content='valid')])
         except Exception as e:
             if raise_exception:
                 raise AppApiException(500, 'Validation failed; check whether api_key and secret_key are correct')

@@ -89,6 +89,9 @@ CACHES = {
     "default": {
         'BACKEND': 'django.core.cache.backends.locmem.LocMemCache',
     },
+    'model_cache': {
+        'BACKEND': 'common.cache.mem_cache.MemCache'
+    },
     # stores user information
     'user_cache': {
         'BACKEND': 'common.cache.file_cache.FileCache',

@@ -12,7 +12,7 @@ djangorestframework = "3.14.0"
 drf-yasg = "1.21.7"
 django-filter = "23.2"
 elasticsearch = "8.9.0"
-langchain = "^0.0.321"
+langchain = "^0.1.11"
 psycopg2-binary = "2.9.7"
 jieba = "^0.42.1"
 diskcache = "^5.6.3"

@@ -21,13 +21,13 @@ filetype = "^1.2.0"
 chardet = "^5.2.0"
 sentence-transformers = "^2.2.2"
 blinker = "^1.6.3"
-openai = "^0.28.1"
+openai = "^1.13.3"
 tiktoken = "^0.5.1"
 qianfan = "^0.1.1"
 pycryptodome = "^3.19.0"
 beautifulsoup4 = "^4.12.2"
 html2text = "^2024.2.26"
-
+langchain-openai = "^0.0.8"
 
 [build-system]
 requires = ["poetry-core"]