diff --git a/README.md b/README.md index 072073ae3..5aeb07d6a 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@
-
+
diff --git a/apps/dataset/serializers/paragraph_serializers.py b/apps/dataset/serializers/paragraph_serializers.py
index 589710ce0..3188766a7 100644
--- a/apps/dataset/serializers/paragraph_serializers.py
+++ b/apps/dataset/serializers/paragraph_serializers.py
@@ -15,7 +15,7 @@ from drf_yasg import openapi
from rest_framework import serializers
from common.db.search import page_search
-from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs, UpdateEmbeddingDatasetIdArgs
+from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
@@ -284,6 +284,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id_list = instance.get("id_list")
QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete()
+ update_document_char_length(self.data.get('document_id'))
# 删除向量库
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
return True
@@ -370,6 +371,8 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
target_document_id, target_dataset_id))
# 修改段落信息
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)
+ update_document_char_length(document_id)
+ update_document_char_length(target_document_id)
@staticmethod
def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
@@ -527,6 +530,7 @@ class ParagraphSerializers(ApiMixin, serializers.Serializer):
paragraph_id = self.data.get('paragraph_id')
QuerySet(Paragraph).filter(id=paragraph_id).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id=paragraph_id).delete()
+ update_document_char_length(self.data.get('document_id'))
ListenerManagement.delete_embedding_by_paragraph_signal.send(paragraph_id)
@staticmethod
diff --git a/apps/setting/models_provider/constants/model_provider_constants.py b/apps/setting/models_provider/constants/model_provider_constants.py
index 1e587bba2..0a7565f38 100644
--- a/apps/setting/models_provider/constants/model_provider_constants.py
+++ b/apps/setting/models_provider/constants/model_provider_constants.py
@@ -17,6 +17,7 @@ from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import
from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider
from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider import ZhiPuModelProvider
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
+from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider
class ModelProvideConstants(Enum):
@@ -29,3 +30,4 @@ class ModelProvideConstants(Enum):
model_zhipu_provider = ZhiPuModelProvider()
model_xf_provider = XunFeiModelProvider()
model_deepseek_provider = DeepSeekModelProvider()
+ model_gemini_provider = GeminiModelProvider()
diff --git a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
index 5731d7e38..3164dd8ea 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py
@@ -21,43 +21,6 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel
from smartdoc.conf import PROJECT_DIR
-"""
-class AzureLLMModelCredential(BaseForm, BaseModelCredential):
-
- def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
- model_type_list = AzureModelProvider().get_model_type_list()
- if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
- raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
-
- for key in ['api_base', 'api_key', 'deployment_name']:
- if key not in model_credential:
- if raise_exception:
- raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
- else:
- return False
- try:
- model = AzureModelProvider().get_model(model_type, model_name, model_credential)
- model.invoke([HumanMessage(content='你好')])
- except Exception as e:
- if isinstance(e, AppApiException):
- raise e
- if raise_exception:
- raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
- else:
- return False
-
- return True
-
- def encryption_dict(self, model: Dict[str, object]):
- return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
-
- api_base = forms.TextInputField('API 版本 (api_version)', required=True)
-
- api_key = forms.PasswordInputField("API Key(API 密钥)", required=True)
-
- deployment_name = forms.TextInputField("部署名(deployment_name)", required=True)
-"""
-
class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
@@ -97,8 +60,6 @@ class DefaultAzureLLMModelCredential(BaseForm, BaseModelCredential):
deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)
-# azure_llm_model_credential: AzureLLMModelCredential = AzureLLMModelCredential()
-
base_azure_llm_model_credential = DefaultAzureLLMModelCredential()
model_dict = {
@@ -114,7 +75,6 @@ class AzureModelProvider(IModelProvider):
return 3
def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> AzureChatModel:
- model_info: ModelInfo = model_dict.get(model_name)
azure_chat_open_ai = AzureChatModel(
azure_endpoint=model_credential.get('api_base'),
openai_api_version=model_credential.get('api_version', '2024-02-15-preview'),
diff --git a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
index f11249de0..6388dbde2 100644
--- a/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
+++ b/apps/setting/models_provider/impl/azure_model_provider/model/azure_chat_model.py
@@ -16,9 +16,15 @@ from common.config.tokenizer_manage_config import TokenizerManage
class AzureChatModel(AzureChatOpenAI):
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
- tokenizer = TokenizerManage.get_tokenizer()
- return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+ try:
+ return super().get_num_tokens_from_messages(messages)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
- tokenizer = TokenizerManage.get_tokenizer()
- return len(tokenizer.encode(text))
+ try:
+ return super().get_num_tokens(text)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/__init__.py b/apps/setting/models_provider/impl/gemini_model_provider/__init__.py
new file mode 100644
index 000000000..43fd3dd05
--- /dev/null
+++ b/apps/setting/models_provider/impl/gemini_model_provider/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :__init__.py
+@Author :Brian Yang
+@Date :5/13/24 7:40 AM
+"""
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py b/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py
new file mode 100644
index 000000000..5ddddf782
--- /dev/null
+++ b/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :gemini_model_provider.py
+@Author :Brian Yang
+@Date :5/13/24 7:47 AM
+"""
+import os
+from typing import Dict
+
+from langchain.schema import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm
+from common.util.file_util import get_file_content
+from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, BaseModelCredential, \
+ ModelInfo, ModelTypeConst, ValidCode
+from setting.models_provider.impl.gemini_model_provider.model.gemini_chat_model import GeminiChatModel
+from smartdoc.conf import PROJECT_DIR
+
+
+class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
+
+ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], raise_exception=False):
+ model_type_list = GeminiModelProvider().get_model_type_list()
+ if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
+ raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+ for key in ['api_key']:
+ if key not in model_credential:
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+ else:
+ return False
+ try:
+ model = GeminiModelProvider().get_model(model_type, model_name, model_credential)
+ model.invoke([HumanMessage(content='你好')])
+ except Exception as e:
+ if isinstance(e, AppApiException):
+ raise e
+ if raise_exception:
+ raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+ else:
+ return False
+ return True
+
+ def encryption_dict(self, model: Dict[str, object]):
+ return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+ api_key = forms.PasswordInputField('API Key', required=True)
+
+
+gemini_llm_model_credential = GeminiLLMModelCredential()
+
+model_dict = {
+ 'gemini-1.0-pro': ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新',
+ ModelTypeConst.LLM,
+ gemini_llm_model_credential,
+ ),
+ 'gemini-1.0-pro-vision': ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新',
+ ModelTypeConst.LLM,
+ gemini_llm_model_credential,
+ ),
+}
+
+
+class GeminiModelProvider(IModelProvider):
+
+ def get_dialogue_number(self):
+ return 3
+
+ def get_model(self, model_type, model_name, model_credential: Dict[str, object],
+ **model_kwargs) -> GeminiChatModel:
+ gemini_chat = GeminiChatModel(
+ model=model_name,
+ google_api_key=model_credential.get('api_key')
+ )
+ return gemini_chat
+
+ def get_model_credential(self, model_type, model_name):
+ if model_name in model_dict:
+ return model_dict.get(model_name).model_credential
+ return gemini_llm_model_credential
+
+ def get_model_provide_info(self):
+ return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content(
+ os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'gemini_model_provider', 'icon',
+ 'gemini_icon_svg')))
+
+ def get_model_list(self, model_type: str):
+ if model_type is None:
+ raise AppApiException(500, '模型类型不能为空')
+ return [model_dict.get(key).to_dict() for key in
+ list(filter(lambda key: model_dict.get(key).model_type == model_type, model_dict.keys()))]
+
+ def get_model_type_list(self):
+ return [{'key': "大语言模型", 'value': "LLM"}]
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/icon/gemini_icon_svg b/apps/setting/models_provider/impl/gemini_model_provider/icon/gemini_icon_svg
new file mode 100644
index 000000000..00c48a359
--- /dev/null
+++ b/apps/setting/models_provider/impl/gemini_model_provider/icon/gemini_icon_svg
@@ -0,0 +1,10 @@
+
diff --git a/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py b/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py
new file mode 100644
index 000000000..7a972d9d5
--- /dev/null
+++ b/apps/setting/models_provider/impl/gemini_model_provider/model/gemini_chat_model.py
@@ -0,0 +1,30 @@
+#!/usr/bin/env python
+# -*- coding: UTF-8 -*-
+"""
+@Project :MaxKB
+@File :gemini_chat_model.py
+@Author :Brian Yang
+@Date :5/13/24 7:40 AM
+"""
+from typing import List
+
+from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_google_genai import ChatGoogleGenerativeAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+
+
+class GeminiChatModel(ChatGoogleGenerativeAI):
+ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+ try:
+ return super().get_num_tokens_from_messages(messages)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
+
+ def get_num_tokens(self, text: str) -> int:
+ try:
+ return super().get_num_tokens(text)
+ except Exception as e:
+ tokenizer = TokenizerManage.get_tokenizer()
+ return len(tokenizer.encode(text))
diff --git a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py
index 3827bdc68..73239921e 100644
--- a/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py
+++ b/apps/setting/models_provider/impl/ollama_model_provider/ollama_model_provider.py
@@ -35,7 +35,7 @@ class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
model_list = OllamaModelProvider.get_base_model_list(model_credential.get('api_base'))
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
- exist = [model for model in model_list.get('models') if
+ exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if
model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name]
if len(exist) == 0:
raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
diff --git a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
index 983561b95..324e851cd 100644
--- a/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
+++ b/apps/setting/models_provider/impl/openai_model_provider/openai_model_provider.py
@@ -60,6 +60,17 @@ model_dict = {
'gpt-3.5-turbo': ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
openai_llm_model_credential,
),
+ 'gpt-4': ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
+ ),
+ 'gpt-4o': ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新',
+ ModelTypeConst.LLM, openai_llm_model_credential,
+ ),
+ 'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
+ openai_llm_model_credential,
+ ),
+ 'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
+ ModelTypeConst.LLM, openai_llm_model_credential,
+ ),
'gpt-3.5-turbo-0125': ModelInfo('gpt-3.5-turbo-0125',
'2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM,
openai_llm_model_credential,
@@ -72,14 +83,10 @@ model_dict = {
'[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用',
ModelTypeConst.LLM, openai_llm_model_credential,
),
- 'gpt-4': ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential,
- ),
- 'gpt-4-turbo': ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM,
- openai_llm_model_credential,
- ),
- 'gpt-4-turbo-preview': ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新',
- ModelTypeConst.LLM, openai_llm_model_credential,
- ),
+ 'gpt-4o-2024-05-13': ModelInfo('gpt-4o-2024-05-13',
+ '2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens',
+ ModelTypeConst.LLM, openai_llm_model_credential,
+ ),
'gpt-4-turbo-2024-04-09': ModelInfo('gpt-4-turbo-2024-04-09',
'2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens',
ModelTypeConst.LLM, openai_llm_model_credential,
diff --git a/pyproject.toml b/pyproject.toml
index f70b250bb..3eddbace0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,7 @@ zhipuai = "^2.0.1"
httpx = "^0.27.0"
httpx-sse = "^0.4.0"
websocket-client = "^1.7.0"
+langchain-google-genai = "^1.0.3"
[build-system]
requires = ["poetry-core"]
diff --git a/ui/package.json b/ui/package.json
index 74b34f093..229b9d3a7 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -31,7 +31,7 @@
"markdown-it-sup": "^1.0.0",
"markdown-it-task-lists": "^2.1.1",
"markdown-it-toc-done-right": "^4.2.0",
- "md-editor-v3": "^4.12.1",
+ "md-editor-v3": "4.12.1",
"medium-zoom": "^1.1.0",
"mermaid": "^10.9.0",
"mitt": "^3.0.0",
diff --git a/ui/src/views/chat/index.vue b/ui/src/views/chat/index.vue
index edd8f67b5..ecb00d3db 100644
--- a/ui/src/views/chat/index.vue
+++ b/ui/src/views/chat/index.vue
@@ -135,6 +135,7 @@ onMounted(() => {
overflow: hidden;
position: relative;
}
+
&__footer {
background: #f3f7f9;
height: 80px;
diff --git a/ui/src/views/dataset/component/EditParagraphDialog.vue b/ui/src/views/dataset/component/EditParagraphDialog.vue
index 15f82307a..ce53dd781 100644
--- a/ui/src/views/dataset/component/EditParagraphDialog.vue
+++ b/ui/src/views/dataset/component/EditParagraphDialog.vue
@@ -95,9 +95,9 @@ function delProblemHandle(item: any, index: number) {
detail.value.problem_list.splice(index, 1)
}
function addProblemHandle() {
- if (problemValue.value) {
+ if (problemValue.value.trim()) {
detail.value?.problem_list?.push({
- content: problemValue.value
+ content: problemValue.value.trim()
})
problemValue.value = ''
isAddProblem.value = false
diff --git a/ui/src/views/dataset/step/StepSecond.vue b/ui/src/views/dataset/step/StepSecond.vue
index eae0fa3f9..59abd10aa 100644
--- a/ui/src/views/dataset/step/StepSecond.vue
+++ b/ui/src/views/dataset/step/StepSecond.vue
@@ -136,10 +136,10 @@ function changeHandle(val: boolean) {
const list = paragraphList.value
list.map((item: any) => {
item.content.map((v: any) => {
- v['problem_list'] = v.title
+ v['problem_list'] = v.title.trim()
? [
{
- content: v.title
+ content: v.title.trim()
}
]
: []
@@ -173,17 +173,17 @@ function splitDocument() {
if (checkedConnect.value) {
list.map((item: any) => {
item.content.map((v: any) => {
- v['problem_list'] = v.title
+ v['problem_list'] = v.title.trim()
? [
{
- content: v.title
+ content: v.title.trim()
}
]
: []
})
})
}
- paragraphList.value = res.data
+ paragraphList.value = list
loading.value = false
})
.catch(() => {
diff --git a/ui/src/views/document/component/ImportDocumentDialog.vue b/ui/src/views/document/component/ImportDocumentDialog.vue
index 6d8e90bad..96e9a6394 100644
--- a/ui/src/views/document/component/ImportDocumentDialog.vue
+++ b/ui/src/views/document/component/ImportDocumentDialog.vue
@@ -21,7 +21,11 @@
type="textarea"
/>
-