Merge branch 'v2' into knowledge_workflow
Some checks failed
sync2gitee / repo-sync (push) Has been cancelled
Typos Check / Spell Check with Typos (push) Has been cancelled

# Conflicts:
#	apps/application/flow/workflow_manage.py
#	apps/common/utils/tool_code.py
#	ui/src/views/tool/component/ToolListContainer.vue
This commit is contained in:
shaohuzhang1 2025-11-21 18:09:39 +08:00
commit 7e7e786bef
201 changed files with 2576 additions and 1001 deletions

View File

@ -5,7 +5,7 @@ on:
inputs:
dockerImageTag:
description: 'Docker Image Tag'
default: 'v2.0.2'
default: 'v2.0.3'
required: true
architecture:
description: 'Architecture'

View File

@ -7,7 +7,7 @@ on:
inputs:
dockerImageTag:
description: 'Image Tag'
default: 'v2.3.0-dev'
default: 'v2.4.0-dev'
required: true
dockerImageTagWithLatest:
description: '是否发布latest tag正式发版时选择测试版本切勿选择'

3
.gitignore vendored
View File

@ -188,4 +188,5 @@ apps/models_provider/impl/*/icon/
apps/models_provider/impl/tencent_model_provider/credential/stt.py
apps/models_provider/impl/tencent_model_provider/model/stt.py
tmp/
config.yml
config.yml
.SANDBOX_BANNED_HOSTS

View File

@ -93,6 +93,7 @@ def event_content(response,
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
else:
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
content_chunk = reasoning._normalize_content(content_chunk)
all_text += content_chunk
if reasoning_content_chunk is None:
reasoning_content_chunk = ''
@ -191,23 +192,17 @@ class BaseChatStep(IChatStep):
manage, padding_problem_text, chat_user_id, chat_user_type,
no_references_setting,
model_setting,
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
mcp_output_enable)
else:
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
paragraph_list,
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
model_setting,
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
mcp_output_enable)
def get_details(self, manage, **kwargs):
# 删除临时生成的MCP代码文件
if self.context.get('execute_ids'):
executor = ToolExecutor(CONFIG.get('SANDBOX'))
# 清理工具代码文件,延时删除,避免文件被占用
for tool_id in self.context.get('execute_ids'):
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
if os.path.exists(code_path):
os.remove(code_path)
return {
'step_type': 'chat_step',
'run_time': self.context['run_time'],
@ -254,7 +249,6 @@ class BaseChatStep(IChatStep):
if tool_enable:
if tool_ids and len(tool_ids) > 0: # 如果有工具ID则将其转换为MCP
self.context['tool_ids'] = tool_ids
self.context['execute_ids'] = []
for tool_id in tool_ids:
tool = QuerySet(Tool).filter(id=tool_id).first()
if tool is None or tool.is_active is False:
@ -264,9 +258,8 @@ class BaseChatStep(IChatStep):
params = json.loads(rsa_long_decrypt(tool.init_params))
else:
params = {}
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
tool_config = executor.get_tool_mcp_config(tool.code, params)
self.context['execute_ids'].append(_id)
mcp_servers_config[str(tool.id)] = tool_config
if len(mcp_servers_config) > 0:
@ -274,7 +267,6 @@ class BaseChatStep(IChatStep):
return None
def get_stream_result(self, message_list: List[BaseMessage],
chat_model: BaseChatModel = None,
paragraph_list=None,
@ -304,7 +296,8 @@ class BaseChatStep(IChatStep):
else:
# 处理 MCP 请求
mcp_result = self._handle_mcp_request(
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
message_list,
)
if mcp_result:
return mcp_result, True
@ -329,7 +322,8 @@ class BaseChatStep(IChatStep):
tool_ids=None,
mcp_output_enable=True):
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
mcp_servers, mcp_source, tool_enable, tool_ids,
mcp_output_enable)
chat_record_id = uuid.uuid7()
r = StreamingHttpResponse(
@ -404,7 +398,9 @@ class BaseChatStep(IChatStep):
# 调用模型
try:
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
no_references_setting, problem_text, mcp_enable,
mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
tool_ids, mcp_output_enable)
if is_ai_chat:
request_token = chat_model.get_num_tokens_from_messages(message_list)
response_token = chat_model.get_num_tokens(chat_result.content)
@ -416,10 +412,10 @@ class BaseChatStep(IChatStep):
reasoning_result_end = reasoning.get_end_reasoning_content()
content = reasoning_result.get('content') + reasoning_result_end.get('content')
if 'reasoning_content' in chat_result.response_metadata:
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
reasoning_content = (chat_result.response_metadata.get('reasoning_content', '') or '')
else:
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
'reasoning_content')
reasoning_content = (reasoning_result.get('reasoning_content') or "") + (reasoning_result_end.get(
'reasoning_content') or "")
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
content, manage, self, padding_problem_text,
reasoning_content=reasoning_content)

View File

@ -95,6 +95,7 @@ class WorkFlowPostHandler:
application_public_access_client.access_num = application_public_access_client.access_num + 1
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
application_public_access_client.save()
self.chat_info = None
class KnowledgeWorkflowPostHandler(WorkFlowPostHandler):

View File

@ -108,9 +108,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
content = reasoning_result.get('content') + reasoning_result_end.get('content')
meta = {**response.response_metadata, **response.additional_kwargs}
if 'reasoning_content' in meta:
reasoning_content = meta.get('reasoning_content', '')
reasoning_content = (meta.get('reasoning_content', '') or '')
else:
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
@ -233,7 +233,6 @@ class BaseChatNode(IChatNode):
if tool_enable:
if tool_ids and len(tool_ids) > 0: # 如果有工具ID则将其转换为MCP
self.context['tool_ids'] = tool_ids
self.context['execute_ids'] = []
for tool_id in tool_ids:
tool = QuerySet(Tool).filter(id=tool_id).first()
if not tool.is_active:
@ -243,9 +242,8 @@ class BaseChatNode(IChatNode):
params = json.loads(rsa_long_decrypt(tool.init_params))
else:
params = {}
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
tool_config = executor.get_tool_mcp_config(tool.code, params)
self.context['execute_ids'].append(_id)
mcp_servers_config[str(tool.id)] = tool_config
if len(mcp_servers_config) > 0:
@ -307,14 +305,6 @@ class BaseChatNode(IChatNode):
return result
def get_details(self, index: int, **kwargs):
# 删除临时生成的MCP代码文件
if self.context.get('execute_ids'):
executor = ToolExecutor(CONFIG.get('SANDBOX'))
# 清理工具代码文件,延时删除,避免文件被占用
for tool_id in self.context.get('execute_ids'):
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
if os.path.exists(code_path):
os.remove(code_path)
return {
'name': self.node.properties.get('stepName'),
"index": index,

View File

@ -74,7 +74,7 @@ class BaseImageToVideoNode(IImageToVideoNode):
def get_file_base64(self, image_url):
try:
if isinstance(image_url, list):
image_url = image_url[0].get('file_id')
image_url = image_url[0].get('file_id') if 'file_id' in image_url[0] else image_url[0].get('url')
if isinstance(image_url, str) and not image_url.startswith('http'):
file = QuerySet(File).filter(id=image_url).first()
file_bytes = file.get_bytes()

View File

@ -131,11 +131,18 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
file_id_list = [image.get('file_id') for image in image_list]
file_id_list = []
url_list = []
for image in image_list:
if 'file_id' in image:
file_id_list.append(image.get('file_id'))
elif 'url' in image:
url_list.append(image.get('url'))
return HumanMessage(content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
*[{'type': 'image_url', 'image_url': {'url': url}} for url in url_list]
])
return HumanMessage(content=chat_record.problem_text)
@ -155,13 +162,22 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
image_list = data['image_list']
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
file_id_list = []
url_list = []
for image in image_list:
if 'file_id' in image:
file_id_list.append(image.get('file_id'))
elif 'url' in image:
url_list.append(image.get('url'))
image_base64_list = [file_id_to_base64(file_id) for file_id in file_id_list]
return HumanMessage(
content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'image_url',
'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
base64_image in image_base64_list]
base64_image in image_base64_list],
*[{'type': 'image_url', 'image_url': url} for url in url_list]
])
return HumanMessage(content=chat_record.problem_text)
@ -177,13 +193,17 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
images.append({'type': 'image_url', 'image_url': {'url': image}})
elif image is not None and len(image) > 0:
for img in image:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
image_bytes = file.get_bytes()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
image_format = what(None, image_bytes)
images.append(
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
if 'file_id' in img:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
image_bytes = file.get_bytes()
base64_image = base64.b64encode(image_bytes).decode("utf-8")
image_format = what(None, image_bytes)
images.append(
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
elif 'url' in img and img['url'].startswith('http'):
images.append(
{'type': 'image_url', 'image_url': {'url': img["url"]}})
return images
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):

View File

@ -226,11 +226,14 @@ class BaseLoopNode(ILoopNode):
def save_context(self, details, workflow_manage):
    """Restore this loop node's runtime state from a persisted details dict.

    Copies the saved loop bookkeeping fields back into self.context and
    re-exposes every loop-scope variable directly on the context so
    downstream nodes can resolve them again.
    """
    # context key -> key inside the persisted details dict
    field_map = {
        'loop_context_data': 'loop_context_data',
        'loop_answer_data': 'loop_answer_data',
        'loop_node_data': 'loop_node_data',
        'result': 'result',
        'params': 'params',
        'run_time': 'run_time',
        'index': 'current_index',
        'item': 'current_item',
    }
    for context_key, details_key in field_map.items():
        self.context[context_key] = details.get(details_key)
    # Flatten the saved loop-scope variables into the node context.
    for key, value in (details.get('loop_context_data') or {}).items():
        self.context[key] = value
    self.answer_text = ""
def get_answer_list(self) -> List[Answer] | None:

View File

@ -37,7 +37,6 @@ class BaseTextToVideoNode(ITextToVideoNode):
self.context['dialogue_type'] = dialogue_type
self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
video_urls = ttv_model.generate_video(question, negative_prompt)
print('video_urls', video_urls)
# 保存图片
if video_urls is None:
return NodeResult({'answer': gettext('Failed to generate video')}, {})

View File

@ -1,9 +1,7 @@
# coding=utf-8
import base64
import mimetypes
import time
from functools import reduce
from imghdr import what
from typing import List, Dict
from django.db.models import QuerySet
@ -12,7 +10,6 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.video_understand_step_node.i_video_understand_node import IVideoUnderstandNode
from knowledge.models import File
from models_provider.impl.volcanic_engine_model_provider.model.image import get_video_format
from models_provider.tools import get_model_instance_by_model_workspace_id
@ -134,11 +131,17 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
# 增加对 None 和空列表的检查
if not video_list or len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
file_id_list = [video.get('file_id') for video in video_list]
file_id_list = []
url_list = []
for image in video_list:
if 'file_id' in image:
file_id_list.append(image.get('file_id'))
elif 'url' in image:
url_list.append(image.get('url'))
return HumanMessage(content=[
{'type': 'text', 'text': data['question']},
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
*[{'type': 'video_url', 'video_url': {'url': url}} for url in url_list],
])
return HumanMessage(content=chat_record.problem_text)
@ -158,6 +161,13 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
video_list = data['video_list']
if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
return HumanMessage(content=chat_record.problem_text)
file_id_list = []
url_list = []
for image in video_list:
if 'file_id' in image:
file_id_list.append(image.get('file_id'))
elif 'url' in image:
url_list.append(image.get('url'))
video_base64_list = [file_id_to_base64(video.get('file_id'), video_model) for video in video_list]
return HumanMessage(
content=[
@ -177,11 +187,15 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
videos.append({'type': 'video_url', 'video_url': {'url': image}})
elif image is not None and len(image) > 0:
for img in image:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
videos.append(
{'type': 'video_url', 'video_url': {'url': url}})
if 'file_id' in img:
file_id = img['file_id']
file = QuerySet(File).filter(id=file_id).first()
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
videos.append(
{'type': 'video_url', 'video_url': {'url': url}})
elif 'url' in img and img['url'].startswith('http'):
videos.append(
{'type': 'video_url', 'video_url': {'url': img['url']}})
return videos
def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):

View File

@ -8,7 +8,8 @@
"""
import asyncio
import json
import traceback
import queue
import threading
from typing import Iterator
from django.http import StreamingHttpResponse
@ -47,6 +48,21 @@ class Reasoning:
return r
return {'content': '', 'reasoning_content': ''}
def _normalize_content(self, content):
    """Coerce an LLM chunk's content into a plain string.

    LangChain message content may be a str, a list of content parts
    (dicts such as ``{'type': 'text', 'text': ...}`` or bare strings),
    or any other object; always return ``str`` so downstream
    concatenation (``all_content += ...``) cannot fail.
    """
    if isinstance(content, str):
        return content
    elif isinstance(content, list):
        # Collect the textual parts; non-text parts (images, tool calls)
        # carry no renderable text and are dropped.
        normalized_parts = []
        for item in content:
            if isinstance(item, str):
                # Bare strings are valid content blocks too; keep them.
                normalized_parts.append(item)
            elif isinstance(item, dict) and item.get('type') == 'text':
                normalized_parts.append(item.get('text', ''))
        return ''.join(normalized_parts)
    else:
        return str(content)
def get_reasoning_content(self, chunk):
# 如果没有开始思考过程标签那么就全是结果
if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
@ -55,6 +71,7 @@ class Reasoning:
# 如果没有结束思考过程标签那么就全部是思考过程
if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
return {'content': '', 'reasoning_content': chunk.content}
chunk.content = self._normalize_content(chunk.content)
self.all_content += chunk.content
if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
if self.all_content.startswith(self.reasoning_content_start_tag):
@ -201,17 +218,6 @@ def to_stream_response_simple(stream_event):
r['Cache-Control'] = 'no-cache'
return r
tool_message_template = """
<details>
<summary>
<strong>Called MCP Tool: <em>%s</em></strong>
</summary>
%s
</details>
"""
tool_message_json_template = """
```json
@ -219,42 +225,142 @@ tool_message_json_template = """
```
"""
# Markdown template for a collapsible MCP tool-call block that shows both
# the call's input arguments and its output.
tool_message_complete_template = """
<details>
<summary>
<strong>Called MCP Tool: <em>%s</em></strong>
</summary>

**Input:**
%s

**Output:**
%s

</details>
"""


def generate_tool_message_complete(name, input_content, output_content):
    """Render a collapsible tool-call block containing input and output.

    Plain text is wrapped in a ```json code fence; content that already
    contains a fence is passed through untouched to avoid nested fences.
    """
    if '```' not in input_content:
        input_formatted = tool_message_json_template % input_content
    else:
        input_formatted = input_content
    if '```' not in output_content:
        output_formatted = tool_message_json_template % output_content
    else:
        output_formatted = output_content
    return tool_message_complete_template % (name, input_formatted, output_formatted)
# Process-wide singleton asyncio event loop, shared by all MCP calls.
_global_loop = None
_loop_thread = None
_loop_lock = threading.Lock()


def get_global_loop():
    """Return the shared event loop, starting its daemon thread on first use.

    The loop runs forever on a dedicated background thread; the lock makes
    the lazy initialization safe when several requests race on first call.
    """
    global _global_loop, _loop_thread
    with _loop_lock:
        if _global_loop is None:
            loop = asyncio.new_event_loop()

            def _pump(loop=loop):
                asyncio.set_event_loop(loop)
                loop.run_forever()

            _global_loop = loop
            _loop_thread = threading.Thread(target=_pump, daemon=True, name="GlobalAsyncLoop")
            _loop_thread.start()
    return _global_loop
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
    """Stream chunks from an MCP-backed react agent.

    AIMessageChunk items are yielded as-is while recording any tool-call
    inputs they announce; when ``mcp_output_enable`` is True, each
    ToolMessage is rewritten into a rendered block combining the recorded
    input with the tool's output before being yielded.

    Raises RuntimeError with the innermost cause's type and message; the
    original traceback is suppressed (``from None``) for a clean client error.
    """
    try:
        client = MultiServerMCPClient(json.loads(mcp_servers))
        tools = await client.get_tools()
        agent = create_react_agent(chat_model, tools)
        response = agent.astream({"messages": message_list}, stream_mode='messages')
        # tool_call_id -> {'name', 'input'} captured from AI chunks, so the
        # later ToolMessage can be rendered together with its input.
        tool_calls_info = {}
        async for chunk in response:
            message = chunk[0]
            if isinstance(message, AIMessageChunk):
                for tool_call in message.additional_kwargs.get('tool_calls', []):
                    tool_id = tool_call.get('id', '')
                    if tool_id:
                        tool_calls_info[tool_id] = {
                            'name': tool_call.get('function', {}).get('name', ''),
                            'input': tool_call.get('function', {}).get('arguments', ''),
                        }
                yield message
            if mcp_output_enable and isinstance(message, ToolMessage):
                tool_info = tool_calls_info.get(message.tool_call_id)
                if tool_info is not None:
                    # Merge the recorded input with the tool output.
                    message.content = generate_tool_message_complete(
                        tool_info['name'], tool_info['input'], message.content)
                yield message
    except ExceptionGroup as eg:
        # TaskGroup/anyio failures arrive wrapped; surface the innermost cause.
        def get_real_error(exc):
            if isinstance(exc, ExceptionGroup):
                return get_real_error(exc.exceptions[0])
            return exc

        real_error = get_real_error(eg)
        raise RuntimeError(f"{type(real_error).__name__}: {str(real_error)}") from None
    except Exception as e:
        raise RuntimeError(f"{type(e).__name__}: {str(e)}") from None
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
    """Bridge the async MCP stream onto the shared global event loop.

    Schedules ``_yield_mcp_response`` on the singleton loop thread and
    relays its chunks synchronously through a queue, so callers can iterate
    a plain generator without owning an event loop or creating one per call.

    Re-raises (after logging) any exception produced by the async stream.
    """
    result_queue = queue.Queue()
    loop = get_global_loop()  # shared loop; never create one per request

    async def _run():
        try:
            async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
            async for chunk in async_gen:
                result_queue.put(('data', chunk))
        except Exception as e:
            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            result_queue.put(('error', e))
        finally:
            # Always signal completion so the consumer loop terminates.
            result_queue.put(('done', None))

    # Schedule the coroutine on the background loop thread.
    asyncio.run_coroutine_threadsafe(_run(), loop)
    while True:
        msg_type, data = result_queue.get()
        if msg_type == 'done':
            break
        if msg_type == 'error':
            raise data
        yield data
async def anext_async(agen):

View File

@ -9,7 +9,6 @@
import concurrent
import json
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor
from functools import reduce
from typing import List, Dict
@ -26,6 +25,7 @@ from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult,
from application.flow.step_node import get_node
from common.handle.base_to_response import BaseToResponse
from common.handle.impl.response.system_to_response import SystemToResponse
from common.utils.logger import maxkb_logger
executor = ThreadPoolExecutor(max_workers=200)
@ -234,25 +234,62 @@ class WorkflowManage:
非流式响应
@return: 结果
"""
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)
return self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
try:
self.run_chain_async(None, None, language)
while self.is_run():
pass
details = self.get_runtime_details()
message_tokens = sum([row.get('message_tokens') for row in details.values() if
'message_tokens' in row and row.get('message_tokens') is not None])
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
'answer_tokens' in row and row.get('answer_tokens') is not None])
answer_text_list = self.get_answer_text_list()
answer_text = '\n\n'.join(
'\n\n'.join([a.get('content') for a in answer]) for answer in
answer_text_list)
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
self.work_flow_post_handler.handler(self)
res = self.base_to_response.to_block_response(self.params['chat_id'],
self.params['chat_record_id'], answer_text, True
, message_tokens, answer_tokens,
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
other_params={'answer_list': answer_list})
finally:
self._cleanup()
return res
def _cleanup(self):
    """Drop every heavyweight reference held by this workflow run.

    Called once a run has finished so the manager instance does not pin
    node results, futures, or chat state in memory after completion.
    """
    # Empty all list-valued collections in place.
    for list_attr in ('future_list', 'field_list', 'global_field_list',
                      'chat_field_list', 'image_list', 'video_list',
                      'document_list', 'audio_list', 'other_list'):
        getattr(self, list_attr).clear()
    # node_context may not exist on every code path.
    if hasattr(self, 'node_context'):
        self.node_context.clear()
    # Empty all dict-valued collections in place.
    for dict_attr in ('context', 'chat_context', 'form_data'):
        getattr(self, dict_attr).clear()
    # Drop object references so they can be garbage collected.
    for ref_attr in ('node_chunk_manage', 'work_flow_post_handler', 'flow',
                     'start_node', 'current_node', 'current_result',
                     'chat_record', 'base_to_response', 'params', 'lock'):
        setattr(self, ref_attr, None)
def run_stream(self, current_node, node_result_future, language='zh'):
"""
@ -307,6 +344,7 @@ class WorkflowManage:
'',
[],
'', True, message_tokens, answer_tokens, {})
self._cleanup()
def run_chain_async(self, current_node, node_result_future, language='zh'):
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
@ -345,7 +383,7 @@ class WorkflowManage:
current_node, node_result_future)
return result
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
return None
def hand_node_result(self, current_node, node_result_future):
@ -357,7 +395,7 @@ class WorkflowManage:
list(result)
return current_result
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
self.status = 500
current_node.get_write_error_context(e)
self.answer += str(e)
@ -435,9 +473,10 @@ class WorkflowManage:
return current_result
except Exception as e:
# 添加节点
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
chunk = self.base_to_response.to_stream_chunk_response(self.params.get('chat_id'),
self.params.get('chat_record_id'),
self.params.get('chat_id'),
current_node.id,
current_node.up_node_id_list,
'Exception:' + str(e), False, 0, 0,

View File

@ -118,3 +118,37 @@ class ApplicationStatisticsSerializer(serializers.Serializer):
days.append(current_date.strftime('%Y-%m-%d'))
current_date += datetime.timedelta(days=1)
return days
def get_token_usage_statistics(self, with_valid=True):
    """Aggregate token usage per asker for this application in the window.

    Runs the raw SQL in apps/application/sql/get_token_usage.sql with a
    dynamic filter on application id and chat-record creation time.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    window_filter = {
        'application_chat.application_id': self.data.get('application_id'),
        'application_chat_record.create_time__gte': self.get_start_time(),
        'application_chat_record.create_time__lte': self.get_end_time(),
    }
    dynamic_model = get_dynamics_model({
        'application_chat.application_id': models.UUIDField(),
        'application_chat_record.create_time': models.DateTimeField(),
    })
    sql_path = os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'get_token_usage.sql')
    return native_search({'default_sql': QuerySet(model=dynamic_model).filter(**window_filter)},
                         select_string=get_file_content(sql_path))
def get_top_questions_statistics(self, with_valid=True):
    """Rank askers by chat-record count for this application in the window.

    Runs the raw SQL in apps/application/sql/top_questions.sql with a
    dynamic filter on application id and chat-record creation time.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    window_filter = {
        'application_chat.application_id': self.data.get('application_id'),
        'application_chat_record.create_time__gte': self.get_start_time(),
        'application_chat_record.create_time__lte': self.get_end_time(),
    }
    dynamic_model = get_dynamics_model({
        'application_chat.application_id': models.UUIDField(),
        'application_chat_record.create_time': models.DateTimeField(),
    })
    sql_path = os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'top_questions.sql')
    return native_search({'default_sql': QuerySet(model=dynamic_model).filter(**window_filter)},
                         select_string=get_file_content(sql_path))

View File

@ -6,7 +6,7 @@
@date2025/6/9 13:42
@desc:
"""
from datetime import datetime
import json
from typing import List
from django.core.cache import cache
@ -20,6 +20,7 @@ from application.models import Application, ChatRecord, Chat, ApplicationVersion
from application.serializers.application_chat import ChatCountSerializer
from common.constants.cache_version import Cache_Version
from common.database_model_manage.database_model_manage import DatabaseModelManage
from common.encoder.encoder import SystemEncoder
from common.exception.app_exception import ChatException
from knowledge.models import Document
from models_provider.models import Model
@ -226,10 +227,69 @@ class ChatInfo:
chat_record.save()
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
def to_dict(self):
    """Serialize this ChatInfo into a cache-friendly plain dict.

    Only the 20 most recent chat records are kept, bounding the size of
    the cached entry. Inverse of ``map_to_chat_info``.
    """
    # Slice before mapping so only the records actually stored get converted.
    recent_records = self.chat_record_list[-20:]
    return {
        'chat_id': self.chat_id,
        'chat_user_id': self.chat_user_id,
        'chat_user_type': self.chat_user_type,
        'knowledge_id_list': self.knowledge_id_list,
        'exclude_document_id_list': self.exclude_document_id_list,
        'application_id': self.application_id,
        'chat_record_list': [self.chat_record_to_map(c) for c in recent_records],
        'debug': self.debug,
    }
def chat_record_to_map(self, chat_record):
    """Serialize a chat record object into a plain dict of its fields."""
    field_names = ('id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
                   'answer_text_list', 'message_tokens', 'answer_tokens', 'const',
                   'details', 'improve_paragraph_id_list', 'run_time', 'index')
    return {name: getattr(chat_record, name) for name in field_names}
@staticmethod
def map_to_chat_record(chat_record_dict):
    """Rebuild a ChatRecord from its dict form (inverse of chat_record_to_map)."""
    field_names = ('id', 'chat_id', 'vote_status', 'problem_text', 'answer_text',
                   'answer_text_list', 'message_tokens', 'answer_tokens', 'const',
                   'details', 'improve_paragraph_id_list', 'run_time', 'index')
    return ChatRecord(**{name: chat_record_dict.get(name) for name in field_names})
def set_cache(self):
    """Cache this chat's serialized form for 30 minutes.

    Stores ``to_dict()`` output (a plain dict, not the live object) under
    the CHAT key with the CHAT_INFO version — the same key/version pair
    that ``get_cache`` reads back.
    """
    cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
              version=Cache_Version.CHAT_INFO.get_version(),
              timeout=60 * 30)
@staticmethod
def map_to_chat_info(chat_info_dict):
    """Rebuild a ChatInfo instance from its cached dict representation.

    Inverse of ``to_dict``; chat records are re-hydrated through
    ``map_to_chat_record``.
    """
    c = ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
                 chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'),
                 chat_info_dict.get('exclude_document_id_list'),
                 chat_info_dict.get('application_id'),
                 debug=chat_info_dict.get('debug'))
    # Tolerate cache entries written without a record list (None -> empty).
    c.chat_record_list = [ChatInfo.map_to_chat_record(c_r)
                          for c_r in (chat_info_dict.get('chat_record_list') or [])]
    return c
@staticmethod
def get_cache(chat_id):
    """Load a ChatInfo from the cache, or None when absent or expired.

    Reads the serialized dict written by ``set_cache`` (CHAT key,
    CHAT_INFO version) and re-hydrates it via ``map_to_chat_info``.
    """
    chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id),
                               version=Cache_Version.CHAT_INFO.get_version())
    if chat_info_dict:
        return ChatInfo.map_to_chat_info(chat_info_dict)
    return None

View File

@ -0,0 +1,12 @@
-- Total token usage (prompt + completion tokens) per asker.
-- '游客' ("guest") is the fallback username for anonymous askers.
-- ${default_sql} is substituted by native_search with the
-- application-id / time-window filter clause.
SELECT
SUM(application_chat_record.message_tokens + application_chat_record.answer_tokens) as "token_usage",
COALESCE(application_chat.asker->>'username', '游客') as "username"
FROM
application_chat_record application_chat_record
LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id
${default_sql}
GROUP BY
COALESCE(application_chat.asker->>'username', '游客')
ORDER BY
"token_usage" DESC

View File

@ -0,0 +1,11 @@
SELECT COUNT(application_chat_record."id") AS chat_record_count,
COALESCE(application_chat.asker ->>'username', '游客') AS username
FROM application_chat_record application_chat_record
LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id
${default_sql}
GROUP BY
COALESCE (application_chat.asker->>'username', '游客')
ORDER BY
chat_record_count DESC,
username ASC

View File

@ -13,6 +13,8 @@ urlpatterns = [
path('workspace/<str:workspace_id>/application/<str:application_id>/publish', views.ApplicationAPI.Publish.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/application_key', views.ApplicationKey.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/application_stats', views.ApplicationStats.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/application_token_usage', views.ApplicationStats.TokenUsageStatistics.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/top_questions', views.ApplicationStats.TopQuestionsStatistics.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/application_key/<str:api_key_id>', views.ApplicationKey.Operate.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/export', views.ApplicationAPI.Export.as_view()),
path('workspace/<str:workspace_id>/application/<str:application_id>/application_version', views.ApplicationVersionView.as_view()),

View File

@ -125,8 +125,8 @@ class OpenView(APIView):
responses=None,
tags=[_('Application')] # type: ignore
)
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
@has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
ViewPermission([RoleConstants.USER.get_workspace_role()],
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
CompareConstants.AND),
@ -167,8 +167,8 @@ class PromptGenerateView(APIView):
responses=None,
tags=[_('Application')] # type: ignore
)
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
@has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
ViewPermission([RoleConstants.USER.get_workspace_role()],
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
CompareConstants.AND),

View File

@ -93,8 +93,8 @@ class ApplicationChatRecordOperateAPI(APIView):
)
@has_permissions(PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_application_permission(),
PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_permission_workspace_manage_role(),
PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
ViewPermission([RoleConstants.USER.get_workspace_role()],
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
CompareConstants.AND),

View File

@ -46,3 +46,58 @@ class ApplicationStats(APIView):
'end_time': request.query_params.get(
'end_time')
}).get_chat_record_aggregate_trend())
class TokenUsageStatistics(APIView):
    """Token consumption per chat user for one application, ordered by usage."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['GET'],
        description=_('Application token usage statistics'),
        summary=_('Application token usage statistics'),
        operation_id=_('Application token usage statistics'),  # type: ignore
        parameters=ApplicationStatsAPI.get_parameters(),
        responses=ApplicationStatsAPI.get_response(),
        tags=[_('Application')]  # type: ignore
    )
    @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
                     PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
                     ViewPermission([RoleConstants.USER.get_workspace_role()],
                                    [PermissionConstants.APPLICATION.get_workspace_application_permission()],
                                    CompareConstants.AND),
                     RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
    def get(self, request: Request, workspace_id: str, application_id: str):
        query = {
            'application_id': application_id,
            'workspace_id': workspace_id,
            'start_time': request.query_params.get('start_time'),
            'end_time': request.query_params.get('end_time'),
        }
        return result.success(
            ApplicationStatisticsSerializer(data=query).get_token_usage_statistics())
class TopQuestionsStatistics(APIView):
    """Question-count statistics per chat user for one application."""
    authentication_classes = [TokenAuth]

    @extend_schema(
        methods=['GET'],
        description=_('Application top question statistics'),
        summary=_('Application top question statistics'),
        operation_id=_('Application top question statistics'),  # type: ignore
        parameters=ApplicationStatsAPI.get_parameters(),
        responses=ApplicationStatsAPI.get_response(),
        tags=[_('Application')]  # type: ignore
    )
    @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
                     PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
                     ViewPermission([RoleConstants.USER.get_workspace_role()],
                                    [PermissionConstants.APPLICATION.get_workspace_application_permission()],
                                    CompareConstants.AND),
                     RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
    def get(self, request: Request, workspace_id: str, application_id: str):
        query = {
            'application_id': application_id,
            'workspace_id': workspace_id,
            'start_time': request.query_params.get('start_time'),
            'end_time': request.query_params.get('end_time'),
        }
        return result.success(
            ApplicationStatisticsSerializer(data=query).get_top_questions_statistics())

View File

@ -6,7 +6,6 @@
@date2023/9/4 11:16
@desc: 认证类
"""
import traceback
from importlib import import_module
from django.conf import settings
@ -18,6 +17,7 @@ from rest_framework.authentication import TokenAuthentication
from common.exception.app_exception import AppAuthenticationFailed, AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed, \
AppApiException
from common.utils.logger import maxkb_logger
token_cache = cache.caches['default']
@ -88,7 +88,7 @@ class TokenAuth(TokenAuthentication):
return handle.handle(request, token, token_details.get_token_details)
raise AppAuthenticationFailed(1002, _('Authentication information is incorrect! illegal user'))
except Exception as e:
traceback.print_stack()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed) or isinstance(e,
AppApiException):
raise e

View File

@ -45,7 +45,7 @@ class ChatAuthentication:
value = json.dumps(self.to_dict())
authentication = encrypt(value)
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
return authentication
@staticmethod

View File

@ -7,18 +7,24 @@
@desc:
"""
import os
class MKTokenizer:
    """Adapter giving a ``tokenizers.Tokenizer``-like object a plain
    ``encode(text) -> list[int]`` interface (raw token ids only)."""

    def __init__(self, tokenizer):
        # Wrapped tokenizer whose encode() returns an object with an `ids` attribute.
        self.tokenizer = tokenizer

    def encode(self, text):
        """Return the list of token ids for *text*."""
        encoding = self.tokenizer.encode(text)
        return encoding.ids
class TokenizerManage:
    """Process-wide lazy holder for the shared BERT tokenizer."""
    tokenizer = None  # cached tokenizers.Tokenizer instance

    @staticmethod
    def get_tokenizer():
        """Return the shared tokenizer, loading it from disk on first use.

        Merge residue removed: the old transformers/BertTokenizer loading
        path was interleaved with the new lightweight ``tokenizers`` path;
        only the latter survives.
        """
        from tokenizers import Tokenizer
        if TokenizerManage.tokenizer is None:
            # Resolve the snapshot hash from the HF cache layout, then load the
            # raw tokenizer.json directly (avoids importing transformers/torch).
            model_path = os.path.join("/opt/maxkb-app", "model", "tokenizer", "models--bert-base-cased")
            with open(f"{model_path}/refs/main", encoding="utf-8") as f:
                snapshot = f.read()
            TokenizerManage.tokenizer = Tokenizer.from_file(f"{model_path}/snapshots/{snapshot}/tokenizer.json")
        return MKTokenizer(TokenizerManage.tokenizer)

View File

@ -30,6 +30,8 @@ class Cache_Version(Enum):
# 对话
CHAT = "CHAT", lambda key: key
CHAT_INFO = "CHAT_INFO", lambda key: key
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
# 应用API KEY

View File

@ -77,5 +77,5 @@ class HTMLSplitHandle(BaseSplitHandle):
content = buffer.decode(encoding)
return html2text(content)
except BaseException as e:
traceback.print_exception(e)
maxkb_logger.error(f'Exception: {e}', exc_info=True)
return f'{e}'

View File

@ -11,7 +11,6 @@ class Services(TextChoices):
gunicorn = 'gunicorn', 'gunicorn'
celery_default = 'celery_default', 'celery_default'
local_model = 'local_model', 'local_model'
scheduler = 'scheduler', 'scheduler'
web = 'web', 'web'
celery = 'celery', 'celery'
celery_model = 'celery_model', 'celery_model'
@ -25,7 +24,6 @@ class Services(TextChoices):
cls.gunicorn.value: services.GunicornService,
cls.celery_default: services.CeleryDefaultService,
cls.local_model: services.GunicornLocalModelService,
cls.scheduler: services.SchedulerService,
}
return services_map.get(name)
@ -41,13 +39,10 @@ class Services(TextChoices):
def task_services(cls):
return cls.celery_services()
@classmethod
def scheduler_services(cls):
return [cls.scheduler]
@classmethod
def all_services(cls):
return cls.web_services() + cls.task_services() + cls.scheduler_services()
return cls.web_services() + cls.task_services()
@classmethod
def export_services_values(cls):
@ -101,7 +96,7 @@ class BaseActionCommand(BaseCommand):
)
parser.add_argument('-d', '--daemon', nargs="?", const=True)
parser.add_argument('-w', '--worker', type=int, nargs="?",
default=2 if os.cpu_count() > 6 else math.floor(os.cpu_count() / 2))
default=3 if os.cpu_count() > 6 else max(1, math.floor(os.cpu_count() / 2)))
parser.add_argument('-f', '--force', nargs="?", const=True)
def initial_util(self, *args, **options):

View File

@ -18,14 +18,17 @@ class GunicornService(BaseService):
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
bind = f'{HTTP_HOST}:{HTTP_PORT}'
max_requests = 10240 if int(self.worker) > 1 else 0
cmd = [
'gunicorn', 'maxkb.wsgi:application',
'-b', bind,
'-k', 'gthread',
'--threads', '200',
'-w', str(self.worker),
'--max-requests', '10240',
'--max-requests', str(max_requests),
'--max-requests-jitter', '2048',
'--timeout', '0',
'--graceful-timeout', '0',
'--access-logformat', log_format,
'--access-logfile', '/dev/null',
'--error-logfile', '-'

View File

@ -27,14 +27,17 @@ class GunicornLocalModelService(BaseService):
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1)
max_requests = 10240 if int(worker) > 1 else 0
cmd = [
'gunicorn', 'maxkb.wsgi:application',
'-b', bind,
'-k', 'gthread',
'--threads', '200',
'-w', str(worker),
'--max-requests', '10240',
'--max-requests', str(max_requests),
'--max-requests-jitter', '2048',
'--timeout', '0',
'--graceful-timeout', '0',
'--access-logformat', log_format,
'--access-logfile', '/dev/null',
'--error-logfile', '-'

View File

@ -18,14 +18,17 @@ class SchedulerService(BaseService):
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
bind = f'127.0.0.1:6060'
max_requests = 10240 if int(self.worker) > 1 else 0
cmd = [
'gunicorn', 'maxkb.wsgi:application',
'-b', bind,
'-k', 'gthread',
'--threads', '200',
'-w', str(self.worker),
'--max-requests', '10240',
'--max-requests', str(max_requests),
'--max-requests-jitter', '2048',
'--timeout', '0',
'--graceful-timeout', '0',
'--access-logformat', log_format,
'--access-logfile', '/dev/null',
'--error-logfile', '-'

View File

@ -10,6 +10,8 @@ from django.db import models
class AppModelMixin(models.Model):
objects = models.Manager()
create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, db_index=True)
update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, db_index=True)

View File

@ -8,6 +8,7 @@
"""
import base64
import threading
from functools import lru_cache
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
from Crypto.PublicKey import RSA
@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
"""
if public_key is None:
public_key = get_key_pair().get('key')
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
cipher = _get_encrypt_cipher(public_key)
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
return base64.b64encode(encrypt_msg).decode()
@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
"""
if pri_key is None:
pri_key = get_key_pair().get('value')
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
cipher = _get_cipher(pri_key)
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
return decrypt_data.decode("utf-8")
@lru_cache(maxsize=2)
def _get_encrypt_cipher(public_key: str):
    """Return (and memoize) the RSA encryption cipher built from *public_key*.

    Parsing a PEM key is relatively expensive; caching keeps at most two
    distinct keys' ciphers alive (the default workspace key plus one other).
    """
    return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
    """
    Encrypt an arbitrarily long string with RSA.

    :param message: plaintext to encrypt
    :param public_key: PEM public key; defaults to the workspace key pair
    :param length: chunk size — 100 for 1024-bit certs, 200 for 2048-bit certs
    :return: base64-encoded ciphertext string
    """
    if public_key is None:
        public_key = get_key_pair().get('key')
    # Merge residue removed: the old inline PKCS1_cipher.new(...) construction
    # duplicated the cached-cipher call below.
    cipher = _get_encrypt_cipher(public_key)
    if len(message) <= length:
        result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
    else:
        # RSA cannot encrypt data longer than the modulus allows ("Plaintext is
        # too long"), so encrypt fixed-size chunks and concatenate the blocks.
        rsa_text = [cipher.encrypt(message[i:i + length].encode('utf-8'))
                    for i in range(0, len(message), length)]
        result = base64.b64encode(b''.join(rsa_text))
    return result.decode()
@lru_cache(maxsize=2)
def _get_cipher(pri_key: str):
    """Return (and memoize) the RSA decryption cipher for *pri_key*.

    Avoids re-parsing the PEM private key on every decrypt call.
    """
    return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
    """
    Decrypt RSA ciphertext of arbitrary length (memory-friendly).

    :param message: base64-encoded ciphertext produced by ``rsa_long_encrypt``
    :param pri_key: PEM private key; defaults to the workspace key pair
    :param length: ciphertext block size — 128 for 1024-bit certs, 256 for 2048-bit certs
    :return: decrypted plaintext string
    """
    if pri_key is None:
        pri_key = get_key_pair().get('value')
    # Merge residue removed: old inline cipher construction and list-based
    # accumulation were interleaved with the new cached-cipher/bytearray body.
    cipher = _get_cipher(pri_key)
    base64_de = base64.b64decode(message)
    # bytearray grows in place, avoiding a list of chunks plus a final join.
    result = bytearray()
    for i in range(0, len(base64_de), length):
        result.extend(cipher.decrypt(base64_de[i:i + length], 0))
    return result.decode()

View File

@ -1,51 +1,77 @@
# coding=utf-8
import ast
import base64
import gzip
import json
import os
import socket
import subprocess
import sys
from textwrap import dedent
import socket
import uuid_utils.compat as uuid
from common.utils.logger import maxkb_logger
from django.utils.translation import gettext_lazy as _
from maxkb.const import BASE_DIR, CONFIG
from maxkb.const import PROJECT_DIR
from textwrap import dedent
python_directory = sys.executable
class ToolExecutor:
def __init__(self, sandbox=False):
    """Prepare sandbox paths, keyword/host blacklists and working directories.

    :param sandbox: True inside the hardened container (tool code runs as the
                    dedicated ``sandbox`` user); False for local development.
    """
    self.sandbox = sandbox
    if sandbox:
        self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox')
        self.user = 'sandbox'
    else:
        self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox')
        self.user = None
    # Stray trailing semicolon removed; merge residue (old _createdir/chown
    # calls) dropped in favour of _init_dir below.
    self.banned_keywords = CONFIG.get("SANDBOX_PYTHON_BANNED_KEYWORDS", 'nothing_is_banned').split(',')
    self.sandbox_so_path = f'{self.sandbox_path}/sandbox.so'
    banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip()
    try:
        if banned_hosts:
            # Always forbid access back to this very host as well.
            hostname = socket.gethostname()
            local_ip = socket.gethostbyname(hostname)
            banned_hosts = f"{banned_hosts},{hostname},{local_ip}"
    except Exception:
        # Host-name resolution is best-effort; keep the configured list as-is.
        pass
    self.banned_hosts = banned_hosts
    try:
        self._init_dir()
    except Exception as e:
        maxkb_logger.error(f'Exception: {e}', exc_info=True)
        # Outside the container, directory setup is best-effort; inside the
        # sandbox container it is mandatory, so re-raise there.
        if self.sandbox:
            raise e
def _init_dir(self):
    """One-time initialisation of the sandbox working directories.

    Uses an exclusively-created lock file so that only the first process
    performs directory creation, permission tightening and blacklist export.
    Merge residue from the old ``_createdir`` implementation removed.
    """
    try:
        fd = os.open(os.path.join(PROJECT_DIR, 'tmp', 'tool_executor_init_dir.lock'),
                     os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        os.close(fd)
    except FileExistsError:
        # Lock file already exists → another process initialised everything.
        return
    maxkb_logger.debug("init dir")
    # Create the work dirs with owner-only permissions regardless of umask.
    old_mask = os.umask(0o077)
    try:
        os.makedirs(self.sandbox_path, 0o700, exist_ok=True)
        os.makedirs(os.path.join(self.sandbox_path, 'execute'), 0o700, exist_ok=True)
        os.makedirs(os.path.join(self.sandbox_path, 'result'), 0o700, exist_ok=True)
    finally:
        os.umask(old_mask)
    if self.sandbox:
        os.system(f"chown -R {self.user}:root {self.sandbox_path}")
        try:
            # Tighten shared kernel/IPC paths reachable from the sandbox user.
            os.system("chmod -R g-rwx /dev/shm /dev/mqueue")
            os.system("chmod o-rwx /run/postgresql")
        except Exception as e:
            maxkb_logger.warning(f'Exception: {e}', exc_info=True)
        if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1":
            tmp_dir_path = os.path.join(self.sandbox_path, 'tmp')
            os.makedirs(tmp_dir_path, 0o700, exist_ok=True)
            os.system(f"chown -R {self.user}:root {tmp_dir_path}")
        if os.path.exists(self.sandbox_so_path):
            os.chmod(self.sandbox_so_path, 0o440)
    # Export the host blacklist where the sandbox preload library reads it.
    banned_hosts_file_path = f'{self.sandbox_path}/.SANDBOX_BANNED_HOSTS'
    if os.path.exists(banned_hosts_file_path):
        os.remove(banned_hosts_file_path)
    if self.banned_hosts:
        with open(banned_hosts_file_path, "w") as f:
            f.write(self.banned_hosts)
        os.chmod(banned_hosts_file_path, 0o440)
def exec_code(self, code_str, keywords, function_name=None):
self.validate_banned_keywords(code_str)
@ -57,9 +83,7 @@ class ToolExecutor:
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
_exec_code = f"""
try:
import os
import sys
import json
import sys, json, base64, builtins
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
sys.path = [p for p in sys.path if p not in path_to_exclude]
sys.path += {python_paths}
@ -71,21 +95,21 @@ try:
for local in locals_v:
globals_v[local] = locals_v[local]
exec_result=f(**keywords)
with open({result_path!a}, 'w') as file:
file.write(json.dumps({success}, default=str))
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({success}, default=str).encode()).decode())
except Exception as e:
with open({result_path!a}, 'w') as file:
file.write(json.dumps({err}))
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({err}, default=str).encode()).decode())
"""
if self.sandbox:
subprocess_result = self._exec_sandbox(_exec_code, _id)
subprocess_result = self._exec_sandbox(_exec_code)
else:
subprocess_result = self._exec(_exec_code)
if subprocess_result.returncode == 1:
raise Exception(subprocess_result.stderr)
with open(result_path, 'r') as file:
result = json.loads(file.read())
os.remove(result_path)
lines = subprocess_result.stdout.splitlines()
result_line = [line for line in lines if line.startswith(_id)]
if not result_line:
raise Exception("No result found.")
result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode())
if result.get('code') == 200:
return result.get('data')
raise Exception(result.get('msg'))
@ -173,52 +197,45 @@ exec({dedent(code)!a})
"""
def get_tool_mcp_config(self, code, params):
    """Build the stdio MCP server launch config for a tool.

    The generated server script is shipped inline (gzip + base64 inside a
    ``python -c`` bootstrap) so nothing is written to disk.

    :param code: tool source code
    :param params: tool parameter definitions
    :return: dict suitable for an MCP stdio connection (command/args/...).
    """
    # Merge residue removed: generate_mcp_server_code was called twice and a
    # temp .py file was written; only the inline-bootstrap path survives.
    _code = self.generate_mcp_server_code(code, params)
    maxkb_logger.debug(f"Python code of mcp tool: {_code}")
    encoded = base64.b64encode(gzip.compress(_code.encode())).decode()
    bootstrap = f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{encoded}\')).decode())'
    if self.sandbox:
        tool_config = {
            'command': 'su',
            'args': [
                '-s', sys.executable,
                '-c', bootstrap,
                self.user,
            ],
            'cwd': self.sandbox_path,
            'env': {
                'SANDBOX_BANNED_HOSTS': self.banned_hosts,
                'LD_PRELOAD': self.sandbox_so_path,
            },
            'transport': 'stdio',
        }
    else:
        tool_config = {
            'command': sys.executable,
            # Bug fix: args must be an argv list with '-c', not a bare code string.
            'args': ['-c', bootstrap],
            'transport': 'stdio',
        }
    return tool_config
def _exec_sandbox(self, _code):
    """Run *_code* as the unprivileged sandbox user and return the CompletedProcess.

    The code is compressed and base64-embedded in a ``-c`` bootstrap so no
    script file has to be written (merge residue of the old file-based
    implementation removed, including the obsolete ``_id`` parameter).
    """
    kwargs = {'cwd': BASE_DIR}
    kwargs['env'] = {
        # The preload library enforces the host blacklist at syscall level.
        'SANDBOX_BANNED_HOSTS': self.banned_hosts,
        'LD_PRELOAD': self.sandbox_so_path,
    }
    maxkb_logger.debug(f"Sandbox execute code: {_code}")
    compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
    subprocess_result = subprocess.run(
        ['su', '-s', python_directory, '-c',
         f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
         self.user],
        text=True,
        capture_output=True, **kwargs)
    return subprocess_result
def validate_banned_keywords(self, code_str):

View File

@ -816,7 +816,7 @@ class DocumentSerializers(serializers.Serializer):
@post(post_function=post_embedding)
@transaction.atomic
def save(self, instance: Dict, with_valid=False, **kwargs):
def save(self, instance: Dict, with_valid=True, **kwargs):
if with_valid:
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
self.is_valid(raise_exception=True)

View File

@ -7,7 +7,6 @@
@desc:
"""
import logging
import traceback
from typing import List

View File

@ -542,9 +542,9 @@ class DocumentView(APIView):
@extend_schema(
methods=['PUT'],
summary=_('Batch generate related documents'),
description=_('Batch generate related documents'),
operation_id=_('Batch generate related documents'), # type: ignore
summary=_('Batch generate related problems'),
description=_('Batch generate related problems'),
operation_id=_('Batch generate related problems'), # type: ignore
request=BatchGenerateRelatedAPI.get_request(),
parameters=BatchGenerateRelatedAPI.get_parameters(),
responses=BatchGenerateRelatedAPI.get_response(),
@ -560,7 +560,7 @@ class DocumentView(APIView):
[PermissionConstants.KNOWLEDGE.get_workspace_knowledge_permission()], CompareConstants.AND),
)
@log(
menu='document', operate="Batch generate related documents",
menu='document', operate="Batch generate related problems",
get_operation_object=lambda r, keywords: get_knowledge_document_operation_object(
get_knowledge_operation_object(keywords.get('knowledge_id')),
get_document_operation_object_batch(r.data.get('document_id_list'))

View File

@ -74,7 +74,6 @@ class ModelManage:
def get_local_model(model, **kwargs):
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
return LocalModelProvider().get_model(model.model_type, model.model_name,
json.loads(
rsa_long_decrypt(model.credential)),
@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer):
query = serializers.CharField(required=True, label=_('query'))
class ValidateModelSerializers(serializers.Serializer):
    """Payload serializer for validating a local model's credential."""
    model_name = serializers.CharField(required=True, label=_('model_name'))
    model_type = serializers.CharField(required=True, label=_('model_type'))
    model_credential = serializers.DictField(required=True, label="credential")

    def validate_model(self, with_valid=True):
        """Check the credential against the local provider; raises on failure."""
        if with_valid:
            self.is_valid(raise_exception=True)
        payload = self.data
        LocalModelProvider().is_valid_credential(payload.get('model_type'),
                                                 payload.get('model_name'),
                                                 payload.get('model_credential'),
                                                 model_params={},
                                                 raise_exception=True)
class ModelApplySerializers(serializers.Serializer):
model_id = serializers.UUIDField(required=True, label=_('model id'))
@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer):
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
instance.get('documents')], instance.get('query'))]
def unload(self, with_valid=True):
    """Evict the model identified by ``model_id`` from the in-memory cache.

    :param with_valid: validate the serializer payload first
    :return: True on completion
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    model_id = self.data.get('model_id')
    ModelManage.delete_key(model_id)
    return True

View File

@ -7,7 +7,9 @@ from . import views
app_name = "local_model"
# @formatter:off
urlpatterns = [
path('model/validate', views.LocalModelApply.Validate.as_view()),
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
]

View File

@ -11,7 +11,7 @@ from urllib.request import Request
from rest_framework.views import APIView
from common.result import result
from local_model.serializers.model_apply_serializers import ModelApplySerializers
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
class LocalModelApply(APIView):
@ -32,3 +32,12 @@ class LocalModelApply(APIView):
def post(self, request: Request, model_id):
return result.success(
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
class Unload(APIView):
    # Unload a locally loaded model from memory.
    def post(self, request: Request, model_id):
        """Evict the model identified by *model_id* from the local model cache."""
        # Bug fix: this endpoint previously delegated to compress_documents()
        # (copy-paste from CompressDocuments); it must call unload().
        return result.success(
            ModelApplySerializers(data={'model_id': model_id}).unload())
class Validate(APIView):
    # Validate a local-model credential without persisting or loading the model.
    def post(self, request: Request):
        return result.success(ValidateModelSerializers(data=request.data).validate_model())

View File

@ -2945,7 +2945,7 @@ msgstr ""
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
#: apps/knowledge/views/document.py:468
msgid "Batch generate related documents"
msgid "Batch generate related problems"
msgstr ""
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
@ -5126,7 +5126,7 @@ msgstr ""
#: apps/resource_manage/views/document.py:261
#: apps/resource_manage/views/document.py:262
#: apps/resource_manage/views/document.py:263
msgid "Batch generate related system document"
msgid "Batch generate related system problems"
msgstr ""
#: apps/resource_manage/views/document.py:290
@ -5946,7 +5946,7 @@ msgstr ""
#: apps/shared/views/shared_document.py:262
#: apps/shared/views/shared_document.py:263
#: apps/shared/views/shared_document.py:264
msgid "Batch generate related shared documents"
msgid "Batch generate related shared problems"
msgstr ""
#: apps/shared/views/shared_document.py:291
@ -6671,7 +6671,7 @@ msgid "Phone"
msgstr ""
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
msgid "Username must be 4-20 characters long"
msgid "Username must be 4-64 characters long"
msgstr ""
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
@ -8766,4 +8766,10 @@ msgid "Non-existent id"
msgstr ""
msgid "No permission for the target folder"
msgstr ""
msgid "Application token usage statistics"
msgstr ""
msgid "Application top question statistics"
msgstr ""

View File

@ -2952,8 +2952,8 @@ msgstr "批量刷新文档向量库"
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
#: apps/knowledge/views/document.py:468
msgid "Batch generate related documents"
msgstr "批量生成相关文档"
msgid "Batch generate related problems"
msgstr "批量生成相关问题"
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
#: apps/knowledge/views/document.py:498
@ -5251,8 +5251,8 @@ msgstr "批量刷新文档向量库"
#: apps/resource_manage/views/document.py:261
#: apps/resource_manage/views/document.py:262
#: apps/resource_manage/views/document.py:263
msgid "Batch generate related system document"
msgstr "批量生成相关文档"
msgid "Batch generate related system problems"
msgstr "批量生成相关问题"
#: apps/resource_manage/views/document.py:290
#: apps/resource_manage/views/document.py:291
@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文档向量库"
#: apps/shared/views/shared_document.py:262
#: apps/shared/views/shared_document.py:263
#: apps/shared/views/shared_document.py:264
msgid "Batch generate related shared documents"
msgstr "批量生成相关共享文档"
msgid "Batch generate related shared problems"
msgstr "批量生成相关问题"
#: apps/shared/views/shared_document.py:291
#: apps/shared/views/shared_document.py:292
@ -6796,8 +6796,8 @@ msgid "Phone"
msgstr "手机"
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
msgid "Username must be 4-20 characters long"
msgstr "用户名必须为4-20个字符"
msgid "Username must be 4-64 characters long"
msgstr "用户名必须为4-64个字符"
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
#: apps/xpack/serializers/chat_user.py:81
@ -8894,3 +8894,8 @@ msgstr "不存在的ID"
msgid "No permission for the target folder"
msgstr "没有目标文件夹的权限"
msgid "Application token usage statistics"
msgstr "应用令牌使用统计"
msgid "Application top question statistics"
msgstr "应用提问次数统计"

View File

@ -2952,8 +2952,8 @@ msgstr "批量刷新文檔向量庫"
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
#: apps/knowledge/views/document.py:468
msgid "Batch generate related documents"
msgstr "批量生成相關文檔"
msgid "Batch generate related problems"
msgstr "批量生成相關問題"
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
#: apps/knowledge/views/document.py:498
@ -5251,8 +5251,8 @@ msgstr "批量刷新文檔向量庫"
#: apps/resource_manage/views/document.py:261
#: apps/resource_manage/views/document.py:262
#: apps/resource_manage/views/document.py:263
msgid "Batch generate related system document"
msgstr "批量生成相關文檔"
msgid "Batch generate related system problems"
msgstr "批量生成相關問題"
#: apps/resource_manage/views/document.py:290
#: apps/resource_manage/views/document.py:291
@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文檔向量庫"
#: apps/shared/views/shared_document.py:262
#: apps/shared/views/shared_document.py:263
#: apps/shared/views/shared_document.py:264
msgid "Batch generate related shared documents"
msgstr "批量生成相關共享文檔"
msgid "Batch generate related shared problems"
msgstr "批量生成相關問題"
#: apps/shared/views/shared_document.py:291
#: apps/shared/views/shared_document.py:292
@ -6796,8 +6796,8 @@ msgid "Phone"
msgstr "手機"
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
msgid "Username must be 4-20 characters long"
msgstr "用戶名必須為4-20個字符"
msgid "Username must be 4-64 characters long"
msgstr "用戶名必須為4-64個字符"
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
#: apps/xpack/serializers/chat_user.py:81
@ -8894,3 +8894,8 @@ msgstr "不存在的ID"
msgid "No permission for the target folder"
msgstr "沒有目標資料夾的權限"
msgid "Application token usage statistics"
msgstr "應用令牌使用統計"
msgid "Application top question statistics"
msgstr "應用提問次數統計"

View File

@ -6,16 +6,36 @@
@date2025/11/5 15:14
@desc:
"""
import builtins
import os
import sys
from django.core.wsgi import get_wsgi_application
class TorchBlocker:
    """Import hook that refuses to load any module whose name contains 'torch'.

    Installed in place of ``builtins.__import__`` so web workers never pull
    the heavyweight torch stack into memory; all other imports pass through.
    """

    def __init__(self):
        # Keep the real import machinery for modules that are allowed.
        self.original_import = builtins.__import__

    def __call__(self, name, *args, **kwargs):
        if 'torch' in name.lower():
            # Blocked: report on stderr and return nothing.
            print(f"Disable package is being imported: 【{name}", file=sys.stderr)
        else:
            return self.original_import(name, *args, **kwargs)
# Install the import interceptor so torch cannot be loaded in this process
builtins.__import__ = TorchBlocker()
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
os.environ['TIKTOKEN_CACHE_DIR'] = '/opt/maxkb-app/model/tokenizer/openai-tiktoken-cl100k-base'
application = get_wsgi_application()
def post_handler():
from common.database_model_manage.database_model_manage import DatabaseModelManage
from common import event
@ -24,14 +44,5 @@ def post_handler():
DatabaseModelManage.init()
def post_scheduler_handler():
from common import job
job.run()
# 启动后处理函数
post_handler()
# 仅在scheduler中启动定时任务dev local_model celery 不需要
if os.environ.get('ENABLE_SCHEDULER') == '1':
post_scheduler_handler()

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from common import forms
@ -7,7 +6,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext as _
from common.utils.logger import maxkb_logger
class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential):
api_url = forms.TextInputField(_('API URL'), required=True)
@ -41,7 +40,7 @@ class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential):
try:
model = provider.get_model(model_type, model_name, model_credential)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/10/16 17:01
@desc:
"""
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
@ -16,6 +15,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
from common.utils.logger import maxkb_logger
class BaiLianEmbeddingModelParams(BaseForm):
dimensions = forms.SingleSelect(
@ -69,7 +69,7 @@ class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_("Hello"))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,11 +6,9 @@
@date2024/7/11 18:41
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
from langchain_core.messages import HumanMessage
from common import forms
from common.exception.app_exception import AppApiException
@ -69,7 +67,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth(model_credential.get('api_key'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext_lazy as _, gettext
@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from common.forms.switch_field import SwitchField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class QwenModelParams(BaseForm):
"""
@ -85,7 +84,7 @@ class ImageToVideoModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -9,7 +8,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class BaiLianLLMModelParams(BaseForm):
temperature = forms.SliderField(
@ -76,7 +75,7 @@ class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
else:
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from common import forms
@ -7,6 +6,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext as _
from common.utils.logger import maxkb_logger
class AliyunBaiLianOmiSTTModelParams(BaseForm):
CueWord = forms.TextInputField(
@ -49,7 +49,7 @@ class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential):
try:
model = provider.get_model(model_type, model_name, model_credential)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
@ -10,7 +9,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
from common.utils.logger import maxkb_logger
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
"""
@ -60,7 +59,7 @@ class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext as _
@ -9,7 +8,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AliyunBaiLianSTTModelParams(BaseForm):
sample_rate = forms.SliderField(
@ -68,7 +67,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class QwenModelParams(BaseForm):
"""
@ -110,7 +109,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AliyunBaiLianTTSModelGeneralParams(BaseForm):
"""
@ -103,7 +102,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,6 +1,5 @@
# coding=utf-8
import traceback
from typing import Dict, Any
from django.utils.translation import gettext_lazy as _, gettext
@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
from common.forms.switch_field import SwitchField
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class QwenModelParams(BaseForm):
"""
@ -87,7 +86,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -54,7 +54,6 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
'callback': None,
**self.params
}
print(recognition_params)
recognition = Recognition(**recognition_params)

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -54,7 +53,7 @@ class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
for chunk in res:
maxkb_logger.info(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext
from common.utils.logger import maxkb_logger
class AnthropicLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -56,7 +55,7 @@ class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,4 +1,3 @@
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -8,7 +7,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from common.utils.logger import maxkb_logger
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
@ -35,7 +34,7 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
except AppApiException:
raise
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
_('Verification failed, please check whether the parameters are correct: {error}').format(

View File

@ -1,4 +1,3 @@
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import ValidCode, BaseModelCredential
from common.utils.logger import maxkb_logger
class BedrockLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -54,7 +53,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
except AppApiException:
raise
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if raise_exception:
raise AppApiException(ValidCode.valid_error.value,
gettext(

View File

@ -6,7 +6,6 @@
@date2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -15,7 +14,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
@ -36,7 +35,7 @@ class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,7 +1,6 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -57,7 +56,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
for chunk in res:
maxkb_logger.info(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 17:08
@desc:
"""
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from django.utils.translation import gettext_lazy as _, gettext
from common.utils.logger import maxkb_logger
class AzureLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -68,7 +67,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
api_version = forms.TextInputField("API Version", required=True)
@ -32,7 +31,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AzureOpenAITTIModelParams(BaseForm):
size = forms.SingleSelect(
@ -67,7 +66,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class AzureOpenAITTSModelGeneralParams(BaseForm):
# alloy, echo, fable, onyx, nova, shimmer
@ -50,7 +49,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -7,9 +7,11 @@
@desc:
"""
from typing import List, Dict
from typing import List, Dict, Optional, Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.runnables import RunnableConfig
from langchain_openai import AzureChatOpenAI
from common.config.tokenizer_manage_config import TokenizerManage
@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
streaming=True,
)
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
    """Return the usage metadata cached by the last ``invoke`` call, or None
    if no list-content response has been processed yet."""
    return self.__dict__.get('_last_generation_info')
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
try:
return super().get_num_tokens_from_messages(messages)
return self.get_last_generation_info().get('input_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
def get_num_tokens(self, text: str) -> int:
try:
return super().get_num_tokens(text)
return self.get_last_generation_info().get('output_tokens', 0)
except Exception as e:
tokenizer = TokenizerManage.get_tokenizer()
return len(tokenizer.encode(text))
def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[list[str]] = None,
        **kwargs: Any,
) -> BaseMessage:
    """Invoke the underlying Azure chat model and normalize list-shaped content.

    Azure can return ``message.content`` as a list of content parts; this
    flattens the ``type == 'text'`` parts into one string so downstream code
    can treat content uniformly, and caches the usage metadata for the
    token-count helpers.
    """
    message = super().invoke(input, config, stop=stop, **kwargs)
    if isinstance(message.content, str):
        # NOTE(review): plain-string responses return here without recording
        # usage metadata, so get_last_generation_info() keeps stale data on
        # this path — confirm this is intended.
        return message
    elif isinstance(message.content, list):
        # Rebuild the response content from its text parts only; non-dict
        # items and parts whose type is not 'text' are silently dropped.
        content = message.content
        normalized_parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get('type') == 'text':
                    normalized_parts.append(item.get('text', ''))
        message.content = ''.join(normalized_parts)
        # Cache token usage so get_num_tokens* can read it later.
        # NOTE(review): message.usage_metadata may be None, which would make
        # dict.update raise TypeError — confirm upstream always populates it.
        self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata)
    return message

View File

@ -1,4 +1,5 @@
# coding=utf-8
import base64
from concurrent.futures import ThreadPoolExecutor
from requests.exceptions import ConnectTimeout, ReadTimeout
from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
@ -15,7 +16,7 @@ from langchain_openai import ChatOpenAI
from langchain_openai.chat_models.base import _create_usage_metadata
from common.config.tokenizer_manage_config import TokenizerManage
from common.utils.logger import maxkb_logger
def custom_get_token_ids(text: str):
tokenizer = TokenizerManage.get_tokenizer()
@ -102,13 +103,13 @@ class BaseChatOpenAI(ChatOpenAI):
future = executor.submit(super().get_num_tokens_from_messages, messages, tools)
try:
response = future.result()
print("请求成功(未超时)")
maxkb_logger.info("请求成功(未超时)")
return response
except Exception as e:
if isinstance(e, ReadTimeout):
raise # 继续抛出
else:
print("except:", e)
maxkb_logger.error("except:", e)
tokenizer = TokenizerManage.get_tokenizer()
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
@ -211,3 +212,20 @@ class BaseChatOpenAI(ChatOpenAI):
self.usage_metadata = chat_result.response_metadata[
'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
return chat_result
def upload_file_and_get_url(self, file_stream, file_name):
    """Embed the raw file bytes as a base64 data URL (no actual upload occurs)."""
    mime_type = get_video_format(file_name)
    encoded = base64.b64encode(file_stream).decode("utf-8")
    return f'data:{mime_type};base64,{encoded}'
def get_video_format(file_name):
    """Map a video file name to its MIME type, defaulting to ``video/mp4``.

    The extension comparison is case-insensitive; unknown or missing
    extensions fall back to the mp4 type.
    """
    _, _, raw_ext = file_name.rpartition('.')
    mime_by_extension = {
        'mp4': 'video/mp4',
        'avi': 'video/avi',
        'mov': 'video/mov',
        'wmv': 'video/x-ms-wmv'
    }
    return mime_by_extension.get(raw_ext.lower(), 'video/mp4')

View File

@ -6,7 +6,6 @@
@date2024/7/11 17:51
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class DeepSeekLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -56,7 +55,7 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -15,7 +14,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@ -35,7 +34,7 @@ class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -53,7 +52,7 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential):
for chunk in res:
maxkb_logger.info(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 17:57
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class GeminiLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -56,7 +55,7 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
api_key = forms.PasswordInputField('API Key', required=True)
@ -30,7 +29,7 @@ class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -23,6 +23,5 @@ class GeminiImage(MaxKBBaseModel, ChatGoogleGenerativeAI):
return GeminiImage(
model=model_name,
google_api_key=model_credential.get('api_key'),
streaming=True,
**optional_params,
)

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:06
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class KimiLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -56,7 +55,7 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/7 14:02
@desc:
"""
import os

# The local_model worker loads the real (torch-backed) credential
# implementation; every other server role uses the lightweight web variant
# that delegates validation over HTTP.
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
    from .model import *
else:
    from .web import *

View File

@ -1,12 +1,11 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file embedding.py
@date2024/7/11 11:06
@Author
@file model.py.py
@date2025/11/7 14:02
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
from common.utils.logger import maxkb_logger
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
@ -35,7 +34,7 @@ class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
model.embed_query(gettext('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/7 14:03
@desc:
"""
from typing import Dict
import requests
from django.utils.translation import gettext_lazy as _
from common import forms
from common.forms import BaseForm
from maxkb.const import CONFIG
from models_provider.base_model_provider import BaseModelCredential
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form for local embedding models.

    Validation is delegated over HTTP to the local_model service, which is
    the only process allowed to actually load the model.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        host = CONFIG.get("LOCAL_MODEL_HOST")
        port = CONFIG.get("LOCAL_MODEL_PORT")
        protocol = CONFIG.get("LOCAL_MODEL_PROTOCOL")
        admin_path = CONFIG.get_admin_path()
        validate_url = f'{protocol}://{host}:{port}{admin_path}/api/model/validate'
        payload = {'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}
        # NOTE(review): no timeout is set, so a stalled local_model service
        # blocks this request indefinitely — confirm whether that is intended.
        result = requests.post(validate_url, json=payload).json()
        if result.get('code', 500) != 200:
            raise Exception(result.get('message'))
        return result.get('data')

    def encryption_dict(self, model: Dict[str, object]):
        # Local credentials hold no secrets; return them unchanged.
        return model

    # Declared after the methods to preserve the original field order.
    cache_folder = forms.TextInputField(_('Model catalog'), required=True)

View File

@ -0,0 +1,14 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file __init__.py.py
@date2025/11/7 14:22
@desc:
"""
import os

# The local_model worker loads the real (torch-backed) credential
# implementation; every other server role uses the lightweight web variant
# that delegates validation over HTTP.
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
    from .model import *
else:
    from .web import *

View File

@ -1,12 +1,11 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file reranker.py
@date2024/9/3 14:33
@Author
@file model.py
@date2025/11/7 14:23
@desc:
"""
import traceback
from typing import Dict
from langchain_core.documents import Document
@ -17,7 +16,7 @@ from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
from django.utils.translation import gettext_lazy as _, gettext
from common.utils.logger import maxkb_logger
class LocalRerankerCredential(BaseForm, BaseModelCredential):
@ -36,7 +35,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential):
model: LocalReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -0,0 +1,37 @@
# coding=utf-8
"""
@project: MaxKB
@Author虎虎
@file web.py
@date2025/11/7 14:23
@desc:
"""
from typing import Dict
import requests
from django.utils.translation import gettext_lazy as _
from common import forms
from common.forms import BaseForm
from maxkb.const import CONFIG
from models_provider.base_model_provider import BaseModelCredential
class LocalRerankerCredential(BaseForm, BaseModelCredential):
    """Credential form for local reranker models.

    Validation is proxied to the local_model service so this process never
    loads the reranker (or torch) itself.
    """

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        base = (
            f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://'
            f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
            f'{CONFIG.get_admin_path()}'
        )
        # NOTE(review): requests.post has no timeout — a hung local_model
        # service blocks validation forever; confirm whether that is intended.
        response = requests.post(f'{base}/api/model/validate', json={
            'model_name': model_name,
            'model_type': model_type,
            'model_credential': model_credential,
        })
        body = response.json()
        if body.get('code', 500) == 200:
            return body.get('data')
        raise Exception(body.get('message'))

    def encryption_dict(self, model: Dict[str, object]):
        # Nothing secret to mask for local models.
        return model

    # Declared after the methods to preserve the original field order.
    cache_folder = forms.TextInputField(_('Model catalog'), required=True)

View File

@ -32,7 +32,6 @@ class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
print('ssss', kwargs.get('model_id', None))
self.model_id = kwargs.get('model_id', None)
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \

View File

@ -6,7 +6,6 @@
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -15,6 +14,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class OpenAIEmbeddingModelParams(BaseForm):
dimensions = forms.SingleSelect(
@ -53,7 +53,7 @@ class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,7 +1,6 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -56,7 +55,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
for chunk in res:
maxkb_logger.info(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -17,7 +16,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class OpenAILLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -58,7 +57,7 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class OpenAISTTModelParams(BaseForm):
language = forms.TextInputField(
@ -38,7 +37,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class OpenAITTIModelParams(BaseForm):
size = forms.SingleSelect(
@ -70,7 +69,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class OpenAITTSModelGeneralParams(BaseForm):
# alloy, echo, fable, onyx, nova, shimmer
@ -49,7 +48,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -15,7 +14,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@ -35,7 +34,7 @@ class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,7 +1,6 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -12,7 +11,7 @@ from common.forms import BaseForm, TooltipLabel
from django.utils.translation import gettext_lazy as _, gettext
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class RegoloImageModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -53,10 +52,8 @@ class RegoloImageModelCredential(BaseForm, BaseModelCredential):
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
for chunk in res:
print(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class RegoloLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -57,7 +56,7 @@ class RegoloLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class RegoloTTIModelParams(BaseForm):
size = forms.SingleSelect(
@ -68,9 +67,8 @@ class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential):
try:
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
print(res)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/12 16:45
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -15,7 +14,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential):
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
@ -35,7 +34,7 @@ class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential)
model.embed_query(_('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,7 +1,6 @@
# coding=utf-8
import base64
import os
import traceback
from typing import Dict
from langchain_core.messages import HumanMessage
@ -56,7 +55,7 @@ class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential):
for chunk in res:
maxkb_logger.info(chunk)
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class SiliconCloudLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -57,7 +56,7 @@ class SiliconCloudLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/9/9 17:51
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
from common.utils.logger import maxkb_logger
class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
@ -36,7 +35,7 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
model: SiliconCloudReranker = provider.get_model(model_type, model_name, model_credential)
model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext as _
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
api_base = forms.TextInputField('API URL', required=True)
@ -31,7 +30,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential,**model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class SiliconCloudTTIModelParams(BaseForm):
size = forms.SingleSelect(
@ -70,7 +69,7 @@ class SiliconCloudTextToImageModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
res = model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -1,5 +1,4 @@
# coding=utf-8
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -8,6 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class SiliconCloudTTSModelGeneralParams(BaseForm):
# alloy, echo, fable, onyx, nova, shimmer
@ -50,7 +50,7 @@ class SiliconCloudTTSModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.check_auth()
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,7 +6,6 @@
@date2024/7/11 18:32
@desc:
"""
import traceback
from typing import Dict
from django.utils.translation import gettext_lazy as _, gettext
@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
from common.utils.logger import maxkb_logger
class TencentCloudLLMModelParams(BaseForm):
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@ -57,7 +56,7 @@ class TencentCloudLLMModelCredential(BaseForm, BaseModelCredential):
model = provider.get_model(model_type, model_name, model_credential, **model_params)
model.invoke([HumanMessage(content=gettext('Hello'))])
except Exception as e:
traceback.print_exc()
maxkb_logger.error(f'Exception: {e}', exc_info=True)
if isinstance(e, AppApiException):
raise e
if raise_exception:

View File

@ -6,9 +6,7 @@
@date2024/4/18 15:28
@desc:
"""
from typing import List, Dict
from langchain_core.messages import BaseMessage, get_buffer_string
from typing import Dict
from common.config.tokenizer_manage_config import TokenizerManage
from models_provider.base_model_provider import MaxKBBaseModel

Some files were not shown because too many files have changed in this diff Show More