mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
Merge branch 'v2' into knowledge_workflow
# Conflicts: # apps/application/flow/workflow_manage.py # apps/common/utils/tool_code.py # ui/src/views/tool/component/ToolListContainer.vue
This commit is contained in:
commit
7e7e786bef
|
|
@ -5,7 +5,7 @@ on:
|
|||
inputs:
|
||||
dockerImageTag:
|
||||
description: 'Docker Image Tag'
|
||||
default: 'v2.0.2'
|
||||
default: 'v2.0.3'
|
||||
required: true
|
||||
architecture:
|
||||
description: 'Architecture'
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ on:
|
|||
inputs:
|
||||
dockerImageTag:
|
||||
description: 'Image Tag'
|
||||
default: 'v2.3.0-dev'
|
||||
default: 'v2.4.0-dev'
|
||||
required: true
|
||||
dockerImageTagWithLatest:
|
||||
description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)'
|
||||
|
|
|
|||
|
|
@ -188,4 +188,5 @@ apps/models_provider/impl/*/icon/
|
|||
apps/models_provider/impl/tencent_model_provider/credential/stt.py
|
||||
apps/models_provider/impl/tencent_model_provider/model/stt.py
|
||||
tmp/
|
||||
config.yml
|
||||
config.yml
|
||||
.SANDBOX_BANNED_HOSTS
|
||||
|
|
@ -93,6 +93,7 @@ def event_content(response,
|
|||
reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
|
||||
else:
|
||||
reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
|
||||
content_chunk = reasoning._normalize_content(content_chunk)
|
||||
all_text += content_chunk
|
||||
if reasoning_content_chunk is None:
|
||||
reasoning_content_chunk = ''
|
||||
|
|
@ -191,23 +192,17 @@ class BaseChatStep(IChatStep):
|
|||
manage, padding_problem_text, chat_user_id, chat_user_type,
|
||||
no_references_setting,
|
||||
model_setting,
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
|
||||
mcp_output_enable)
|
||||
else:
|
||||
return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
|
||||
paragraph_list,
|
||||
manage, padding_problem_text, chat_user_id, chat_user_type, no_references_setting,
|
||||
model_setting,
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
|
||||
mcp_output_enable)
|
||||
|
||||
def get_details(self, manage, **kwargs):
|
||||
# 删除临时生成的MCP代码文件
|
||||
if self.context.get('execute_ids'):
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
# 清理工具代码文件,延时删除,避免文件被占用
|
||||
for tool_id in self.context.get('execute_ids'):
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
if os.path.exists(code_path):
|
||||
os.remove(code_path)
|
||||
return {
|
||||
'step_type': 'chat_step',
|
||||
'run_time': self.context['run_time'],
|
||||
|
|
@ -254,7 +249,6 @@ class BaseChatStep(IChatStep):
|
|||
if tool_enable:
|
||||
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||
self.context['tool_ids'] = tool_ids
|
||||
self.context['execute_ids'] = []
|
||||
for tool_id in tool_ids:
|
||||
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||
if tool is None or tool.is_active is False:
|
||||
|
|
@ -264,9 +258,8 @@ class BaseChatStep(IChatStep):
|
|||
params = json.loads(rsa_long_decrypt(tool.init_params))
|
||||
else:
|
||||
params = {}
|
||||
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
|
||||
self.context['execute_ids'].append(_id)
|
||||
mcp_servers_config[str(tool.id)] = tool_config
|
||||
|
||||
if len(mcp_servers_config) > 0:
|
||||
|
|
@ -274,7 +267,6 @@ class BaseChatStep(IChatStep):
|
|||
|
||||
return None
|
||||
|
||||
|
||||
def get_stream_result(self, message_list: List[BaseMessage],
|
||||
chat_model: BaseChatModel = None,
|
||||
paragraph_list=None,
|
||||
|
|
@ -304,7 +296,8 @@ class BaseChatStep(IChatStep):
|
|||
else:
|
||||
# 处理 MCP 请求
|
||||
mcp_result = self._handle_mcp_request(
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
|
||||
mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
|
||||
message_list,
|
||||
)
|
||||
if mcp_result:
|
||||
return mcp_result, True
|
||||
|
|
@ -329,7 +322,8 @@ class BaseChatStep(IChatStep):
|
|||
tool_ids=None,
|
||||
mcp_output_enable=True):
|
||||
chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
|
||||
no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
|
||||
mcp_servers, mcp_source, tool_enable, tool_ids,
|
||||
mcp_output_enable)
|
||||
chat_record_id = uuid.uuid7()
|
||||
r = StreamingHttpResponse(
|
||||
|
|
@ -404,7 +398,9 @@ class BaseChatStep(IChatStep):
|
|||
# 调用模型
|
||||
try:
|
||||
chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
|
||||
no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
|
||||
no_references_setting, problem_text, mcp_enable,
|
||||
mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
|
||||
tool_ids, mcp_output_enable)
|
||||
if is_ai_chat:
|
||||
request_token = chat_model.get_num_tokens_from_messages(message_list)
|
||||
response_token = chat_model.get_num_tokens(chat_result.content)
|
||||
|
|
@ -416,10 +412,10 @@ class BaseChatStep(IChatStep):
|
|||
reasoning_result_end = reasoning.get_end_reasoning_content()
|
||||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
if 'reasoning_content' in chat_result.response_metadata:
|
||||
reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
|
||||
reasoning_content = (chat_result.response_metadata.get('reasoning_content', '') or '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
|
||||
'reasoning_content')
|
||||
reasoning_content = (reasoning_result.get('reasoning_content') or "") + (reasoning_result_end.get(
|
||||
'reasoning_content') or "")
|
||||
post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
|
||||
content, manage, self, padding_problem_text,
|
||||
reasoning_content=reasoning_content)
|
||||
|
|
|
|||
|
|
@ -95,6 +95,7 @@ class WorkFlowPostHandler:
|
|||
application_public_access_client.access_num = application_public_access_client.access_num + 1
|
||||
application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
|
||||
application_public_access_client.save()
|
||||
self.chat_info = None
|
||||
|
||||
|
||||
class KnowledgeWorkflowPostHandler(WorkFlowPostHandler):
|
||||
|
|
|
|||
|
|
@ -108,9 +108,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
|
|||
content = reasoning_result.get('content') + reasoning_result_end.get('content')
|
||||
meta = {**response.response_metadata, **response.additional_kwargs}
|
||||
if 'reasoning_content' in meta:
|
||||
reasoning_content = meta.get('reasoning_content', '')
|
||||
reasoning_content = (meta.get('reasoning_content', '') or '')
|
||||
else:
|
||||
reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
|
||||
reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
|
||||
_write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
|
||||
|
||||
|
||||
|
|
@ -233,7 +233,6 @@ class BaseChatNode(IChatNode):
|
|||
if tool_enable:
|
||||
if tool_ids and len(tool_ids) > 0: # 如果有工具ID,则将其转换为MCP
|
||||
self.context['tool_ids'] = tool_ids
|
||||
self.context['execute_ids'] = []
|
||||
for tool_id in tool_ids:
|
||||
tool = QuerySet(Tool).filter(id=tool_id).first()
|
||||
if not tool.is_active:
|
||||
|
|
@ -243,9 +242,8 @@ class BaseChatNode(IChatNode):
|
|||
params = json.loads(rsa_long_decrypt(tool.init_params))
|
||||
else:
|
||||
params = {}
|
||||
_id, tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
tool_config = executor.get_tool_mcp_config(tool.code, params)
|
||||
|
||||
self.context['execute_ids'].append(_id)
|
||||
mcp_servers_config[str(tool.id)] = tool_config
|
||||
|
||||
if len(mcp_servers_config) > 0:
|
||||
|
|
@ -307,14 +305,6 @@ class BaseChatNode(IChatNode):
|
|||
return result
|
||||
|
||||
def get_details(self, index: int, **kwargs):
|
||||
# 删除临时生成的MCP代码文件
|
||||
if self.context.get('execute_ids'):
|
||||
executor = ToolExecutor(CONFIG.get('SANDBOX'))
|
||||
# 清理工具代码文件,延时删除,避免文件被占用
|
||||
for tool_id in self.context.get('execute_ids'):
|
||||
code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
|
||||
if os.path.exists(code_path):
|
||||
os.remove(code_path)
|
||||
return {
|
||||
'name': self.node.properties.get('stepName'),
|
||||
"index": index,
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ class BaseImageToVideoNode(IImageToVideoNode):
|
|||
def get_file_base64(self, image_url):
|
||||
try:
|
||||
if isinstance(image_url, list):
|
||||
image_url = image_url[0].get('file_id')
|
||||
image_url = image_url[0].get('file_id') if 'file_id' in image_url[0] else image_url[0].get('url')
|
||||
if isinstance(image_url, str) and not image_url.startswith('http'):
|
||||
file = QuerySet(File).filter(id=image_url).first()
|
||||
file_bytes = file.get_bytes()
|
||||
|
|
|
|||
|
|
@ -131,11 +131,18 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = [image.get('file_id') for image in image_list]
|
||||
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in image_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
*[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
|
||||
*[{'type': 'image_url', 'image_url': {'url': url}} for url in url_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -155,13 +162,22 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
image_list = data['image_list']
|
||||
if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in image_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
image_base64_list = [file_id_to_base64(file_id) for file_id in file_id_list]
|
||||
|
||||
return HumanMessage(
|
||||
content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'image_url',
|
||||
'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
|
||||
base64_image in image_base64_list]
|
||||
base64_image in image_base64_list],
|
||||
*[{'type': 'image_url', 'image_url': url} for url in url_list]
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -177,13 +193,17 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
|
|||
images.append({'type': 'image_url', 'image_url': {'url': image}})
|
||||
elif image is not None and len(image) > 0:
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_bytes()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes)
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
if 'file_id' in img:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
image_bytes = file.get_bytes()
|
||||
base64_image = base64.b64encode(image_bytes).decode("utf-8")
|
||||
image_format = what(None, image_bytes)
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
|
||||
elif 'url' in img and img['url'].startswith('http'):
|
||||
images.append(
|
||||
{'type': 'image_url', 'image_url': {'url': img["url"]}})
|
||||
return images
|
||||
|
||||
def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
|
||||
|
|
|
|||
|
|
@ -226,11 +226,14 @@ class BaseLoopNode(ILoopNode):
|
|||
def save_context(self, details, workflow_manage):
|
||||
self.context['loop_context_data'] = details.get('loop_context_data')
|
||||
self.context['loop_answer_data'] = details.get('loop_answer_data')
|
||||
self.context['loop_node_data'] = details.get('loop_node_data')
|
||||
self.context['result'] = details.get('result')
|
||||
self.context['params'] = details.get('params')
|
||||
self.context['run_time'] = details.get('run_time')
|
||||
self.context['index'] = details.get('current_index')
|
||||
self.context['item'] = details.get('current_item')
|
||||
for key, value in (details.get('loop_context_data') or {}).items():
|
||||
self.context[key] = value
|
||||
self.answer_text = ""
|
||||
|
||||
def get_answer_list(self) -> List[Answer] | None:
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ class BaseTextToVideoNode(ITextToVideoNode):
|
|||
self.context['dialogue_type'] = dialogue_type
|
||||
self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
|
||||
video_urls = ttv_model.generate_video(question, negative_prompt)
|
||||
print('video_urls', video_urls)
|
||||
# 保存图片
|
||||
if video_urls is None:
|
||||
return NodeResult({'answer': gettext('Failed to generate video')}, {})
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
import time
|
||||
from functools import reduce
|
||||
from imghdr import what
|
||||
from typing import List, Dict
|
||||
|
||||
from django.db.models import QuerySet
|
||||
|
|
@ -12,7 +10,6 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
|
|||
from application.flow.i_step_node import NodeResult, INode
|
||||
from application.flow.step_node.video_understand_step_node.i_video_understand_node import IVideoUnderstandNode
|
||||
from knowledge.models import File
|
||||
from models_provider.impl.volcanic_engine_model_provider.model.image import get_video_format
|
||||
from models_provider.tools import get_model_instance_by_model_workspace_id
|
||||
|
||||
|
||||
|
|
@ -134,11 +131,17 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
# 增加对 None 和空列表的检查
|
||||
if not video_list or len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = [video.get('file_id') for video in video_list]
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in video_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
return HumanMessage(content=[
|
||||
{'type': 'text', 'text': data['question']},
|
||||
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
|
||||
|
||||
*[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
|
||||
*[{'type': 'video_url', 'video_url': {'url': url}} for url in url_list],
|
||||
])
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
|
||||
|
|
@ -158,6 +161,13 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
video_list = data['video_list']
|
||||
if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
|
||||
return HumanMessage(content=chat_record.problem_text)
|
||||
file_id_list = []
|
||||
url_list = []
|
||||
for image in video_list:
|
||||
if 'file_id' in image:
|
||||
file_id_list.append(image.get('file_id'))
|
||||
elif 'url' in image:
|
||||
url_list.append(image.get('url'))
|
||||
video_base64_list = [file_id_to_base64(video.get('file_id'), video_model) for video in video_list]
|
||||
return HumanMessage(
|
||||
content=[
|
||||
|
|
@ -177,11 +187,15 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
|
|||
videos.append({'type': 'video_url', 'video_url': {'url': image}})
|
||||
elif image is not None and len(image) > 0:
|
||||
for img in image:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': url}})
|
||||
if 'file_id' in img:
|
||||
file_id = img['file_id']
|
||||
file = QuerySet(File).filter(id=file_id).first()
|
||||
url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': url}})
|
||||
elif 'url' in img and img['url'].startswith('http'):
|
||||
videos.append(
|
||||
{'type': 'video_url', 'video_url': {'url': img['url']}})
|
||||
return videos
|
||||
|
||||
def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@
|
|||
"""
|
||||
import asyncio
|
||||
import json
|
||||
import traceback
|
||||
import queue
|
||||
import threading
|
||||
from typing import Iterator
|
||||
|
||||
from django.http import StreamingHttpResponse
|
||||
|
|
@ -47,6 +48,21 @@ class Reasoning:
|
|||
return r
|
||||
return {'content': '', 'reasoning_content': ''}
|
||||
|
||||
def _normalize_content(self, content):
|
||||
"""将不同类型的内容统一转换为字符串"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# 处理包含多种内容类型的列表
|
||||
normalized_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
normalized_parts.append(item.get('text', ''))
|
||||
return ''.join(normalized_parts)
|
||||
else:
|
||||
return str(content)
|
||||
|
||||
def get_reasoning_content(self, chunk):
|
||||
# 如果没有开始思考过程标签那么就全是结果
|
||||
if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
|
||||
|
|
@ -55,6 +71,7 @@ class Reasoning:
|
|||
# 如果没有结束思考过程标签那么就全部是思考过程
|
||||
if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
|
||||
return {'content': '', 'reasoning_content': chunk.content}
|
||||
chunk.content = self._normalize_content(chunk.content)
|
||||
self.all_content += chunk.content
|
||||
if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
|
||||
if self.all_content.startswith(self.reasoning_content_start_tag):
|
||||
|
|
@ -201,17 +218,6 @@ def to_stream_response_simple(stream_event):
|
|||
r['Cache-Control'] = 'no-cache'
|
||||
return r
|
||||
|
||||
tool_message_template = """
|
||||
<details>
|
||||
<summary>
|
||||
<strong>Called MCP Tool: <em>%s</em></strong>
|
||||
</summary>
|
||||
|
||||
%s
|
||||
|
||||
</details>
|
||||
|
||||
"""
|
||||
|
||||
tool_message_json_template = """
|
||||
```json
|
||||
|
|
@ -219,42 +225,142 @@ tool_message_json_template = """
|
|||
```
|
||||
"""
|
||||
|
||||
tool_message_complete_template = """
|
||||
<details>
|
||||
<summary>
|
||||
<strong>Called MCP Tool: <em>%s</em></strong>
|
||||
</summary>
|
||||
|
||||
def generate_tool_message_template(name, context):
|
||||
if '```' in context:
|
||||
return tool_message_template % (name, context)
|
||||
**Input:**
|
||||
%s
|
||||
|
||||
**Output:**
|
||||
%s
|
||||
|
||||
</details>
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def generate_tool_message_complete(name, input_content, output_content):
|
||||
"""生成包含输入和输出的工具消息模版"""
|
||||
# 格式化输入
|
||||
if '```' not in input_content:
|
||||
input_formatted = tool_message_json_template % input_content
|
||||
else:
|
||||
return tool_message_template % (name, tool_message_json_template % (context))
|
||||
input_formatted = input_content
|
||||
|
||||
# 格式化输出
|
||||
if '```' not in output_content:
|
||||
output_formatted = tool_message_json_template % output_content
|
||||
else:
|
||||
output_formatted = output_content
|
||||
|
||||
return tool_message_complete_template % (name, input_formatted, output_formatted)
|
||||
|
||||
|
||||
# 全局单例事件循环
|
||||
_global_loop = None
|
||||
_loop_thread = None
|
||||
_loop_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_global_loop():
|
||||
"""获取全局共享的事件循环"""
|
||||
global _global_loop, _loop_thread
|
||||
|
||||
with _loop_lock:
|
||||
if _global_loop is None:
|
||||
_global_loop = asyncio.new_event_loop()
|
||||
|
||||
def run_forever():
|
||||
asyncio.set_event_loop(_global_loop)
|
||||
_global_loop.run_forever()
|
||||
|
||||
_loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop")
|
||||
_loop_thread.start()
|
||||
|
||||
return _global_loop
|
||||
|
||||
|
||||
async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
||||
client = MultiServerMCPClient(json.loads(mcp_servers))
|
||||
tools = await client.get_tools()
|
||||
agent = create_react_agent(chat_model, tools)
|
||||
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||
async for chunk in response:
|
||||
if mcp_output_enable and isinstance(chunk[0], ToolMessage):
|
||||
content = generate_tool_message_template(chunk[0].name, chunk[0].content)
|
||||
chunk[0].content = content
|
||||
yield chunk[0]
|
||||
if isinstance(chunk[0], AIMessageChunk):
|
||||
yield chunk[0]
|
||||
try:
|
||||
client = MultiServerMCPClient(json.loads(mcp_servers))
|
||||
tools = await client.get_tools()
|
||||
agent = create_react_agent(chat_model, tools)
|
||||
response = agent.astream({"messages": message_list}, stream_mode='messages')
|
||||
|
||||
# 用于存储工具调用信息
|
||||
tool_calls_info = {}
|
||||
|
||||
async for chunk in response:
|
||||
if isinstance(chunk[0], AIMessageChunk):
|
||||
tool_calls = chunk[0].additional_kwargs.get('tool_calls', [])
|
||||
for tool_call in tool_calls:
|
||||
tool_id = tool_call.get('id', '')
|
||||
if tool_id:
|
||||
# 保存工具调用的输入
|
||||
tool_calls_info[tool_id] = {
|
||||
'name': tool_call.get('function', {}).get('name', ''),
|
||||
'input': tool_call.get('function', {}).get('arguments', '')
|
||||
}
|
||||
yield chunk[0]
|
||||
|
||||
if mcp_output_enable and isinstance(chunk[0], ToolMessage):
|
||||
tool_id = chunk[0].tool_call_id
|
||||
if tool_id in tool_calls_info:
|
||||
# 合并输入和输出
|
||||
tool_info = tool_calls_info[tool_id]
|
||||
content = generate_tool_message_complete(
|
||||
tool_info['name'],
|
||||
tool_info['input'],
|
||||
chunk[0].content
|
||||
)
|
||||
chunk[0].content = content
|
||||
yield chunk[0]
|
||||
|
||||
except ExceptionGroup as eg:
|
||||
|
||||
def get_real_error(exc):
|
||||
if isinstance(exc, ExceptionGroup):
|
||||
return get_real_error(exc.exceptions[0])
|
||||
return exc
|
||||
|
||||
real_error = get_real_error(eg)
|
||||
error_msg = f"{type(real_error).__name__}: {str(real_error)}"
|
||||
raise RuntimeError(error_msg) from None
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"{type(e).__name__}: {str(e)}"
|
||||
raise RuntimeError(error_msg) from None
|
||||
|
||||
|
||||
def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
|
||||
while True:
|
||||
try:
|
||||
chunk = loop.run_until_complete(anext_async(async_gen))
|
||||
yield chunk
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
finally:
|
||||
loop.close()
|
||||
"""使用全局事件循环,不创建新实例"""
|
||||
result_queue = queue.Queue()
|
||||
loop = get_global_loop() # 使用共享循环
|
||||
|
||||
async def _run():
|
||||
try:
|
||||
async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
|
||||
async for chunk in async_gen:
|
||||
result_queue.put(('data', chunk))
|
||||
except Exception as e:
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
result_queue.put(('error', e))
|
||||
finally:
|
||||
result_queue.put(('done', None))
|
||||
|
||||
# 在全局循环中调度任务
|
||||
asyncio.run_coroutine_threadsafe(_run(), loop)
|
||||
|
||||
while True:
|
||||
msg_type, data = result_queue.get()
|
||||
if msg_type == 'done':
|
||||
break
|
||||
if msg_type == 'error':
|
||||
raise data
|
||||
yield data
|
||||
|
||||
|
||||
async def anext_async(agen):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@
|
|||
import concurrent
|
||||
import json
|
||||
import threading
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
|
@ -26,6 +25,7 @@ from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult,
|
|||
from application.flow.step_node import get_node
|
||||
from common.handle.base_to_response import BaseToResponse
|
||||
from common.handle.impl.response.system_to_response import SystemToResponse
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
executor = ThreadPoolExecutor(max_workers=200)
|
||||
|
||||
|
|
@ -234,25 +234,62 @@ class WorkflowManage:
|
|||
非流式响应
|
||||
@return: 结果
|
||||
"""
|
||||
self.run_chain_async(None, None, language)
|
||||
while self.is_run():
|
||||
pass
|
||||
details = self.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
answer_text_list = self.get_answer_text_list()
|
||||
answer_text = '\n\n'.join(
|
||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||
self.work_flow_post_handler.handler(self)
|
||||
return self.base_to_response.to_block_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'], answer_text, True
|
||||
, message_tokens, answer_tokens,
|
||||
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
other_params={'answer_list': answer_list})
|
||||
try:
|
||||
self.run_chain_async(None, None, language)
|
||||
while self.is_run():
|
||||
pass
|
||||
details = self.get_runtime_details()
|
||||
message_tokens = sum([row.get('message_tokens') for row in details.values() if
|
||||
'message_tokens' in row and row.get('message_tokens') is not None])
|
||||
answer_tokens = sum([row.get('answer_tokens') for row in details.values() if
|
||||
'answer_tokens' in row and row.get('answer_tokens') is not None])
|
||||
answer_text_list = self.get_answer_text_list()
|
||||
answer_text = '\n\n'.join(
|
||||
'\n\n'.join([a.get('content') for a in answer]) for answer in
|
||||
answer_text_list)
|
||||
answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, [])
|
||||
self.work_flow_post_handler.handler(self)
|
||||
|
||||
res = self.base_to_response.to_block_response(self.params['chat_id'],
|
||||
self.params['chat_record_id'], answer_text, True
|
||||
, message_tokens, answer_tokens,
|
||||
_status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
other_params={'answer_list': answer_list})
|
||||
finally:
|
||||
self._cleanup()
|
||||
return res
|
||||
|
||||
def _cleanup(self):
|
||||
"""清理所有对象引用"""
|
||||
# 清理列表
|
||||
self.future_list.clear()
|
||||
self.field_list.clear()
|
||||
self.global_field_list.clear()
|
||||
self.chat_field_list.clear()
|
||||
self.image_list.clear()
|
||||
self.video_list.clear()
|
||||
self.document_list.clear()
|
||||
self.audio_list.clear()
|
||||
self.other_list.clear()
|
||||
if hasattr(self, 'node_context'):
|
||||
self.node_context.clear()
|
||||
|
||||
# 清理字典
|
||||
self.context.clear()
|
||||
self.chat_context.clear()
|
||||
self.form_data.clear()
|
||||
|
||||
# 清理对象引用
|
||||
self.node_chunk_manage = None
|
||||
self.work_flow_post_handler = None
|
||||
self.flow = None
|
||||
self.start_node = None
|
||||
self.current_node = None
|
||||
self.current_result = None
|
||||
self.chat_record = None
|
||||
self.base_to_response = None
|
||||
self.params = None
|
||||
self.lock = None
|
||||
|
||||
def run_stream(self, current_node, node_result_future, language='zh'):
|
||||
"""
|
||||
|
|
@ -307,6 +344,7 @@ class WorkflowManage:
|
|||
'',
|
||||
[],
|
||||
'', True, message_tokens, answer_tokens, {})
|
||||
self._cleanup()
|
||||
|
||||
def run_chain_async(self, current_node, node_result_future, language='zh'):
|
||||
future = executor.submit(self.run_chain_manage, current_node, node_result_future, language)
|
||||
|
|
@ -345,7 +383,7 @@ class WorkflowManage:
|
|||
current_node, node_result_future)
|
||||
return result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
return None
|
||||
|
||||
def hand_node_result(self, current_node, node_result_future):
|
||||
|
|
@ -357,7 +395,7 @@ class WorkflowManage:
|
|||
list(result)
|
||||
return current_result
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
self.status = 500
|
||||
current_node.get_write_error_context(e)
|
||||
self.answer += str(e)
|
||||
|
|
@ -435,9 +473,10 @@ class WorkflowManage:
|
|||
return current_result
|
||||
except Exception as e:
|
||||
# 添加节点
|
||||
traceback.print_exc()
|
||||
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
chunk = self.base_to_response.to_stream_chunk_response(self.params.get('chat_id'),
|
||||
self.params.get('chat_record_id'),
|
||||
self.params.get('chat_id'),
|
||||
current_node.id,
|
||||
current_node.up_node_id_list,
|
||||
'Exception:' + str(e), False, 0, 0,
|
||||
|
|
|
|||
|
|
@ -118,3 +118,37 @@ class ApplicationStatisticsSerializer(serializers.Serializer):
|
|||
days.append(current_date.strftime('%Y-%m-%d'))
|
||||
current_date += datetime.timedelta(days=1)
|
||||
return days
|
||||
|
||||
def get_token_usage_statistics(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
start_time = self.get_start_time()
|
||||
end_time = self.get_end_time()
|
||||
get_token_usage = native_search(
|
||||
{'default_sql': QuerySet(model=get_dynamics_model(
|
||||
{'application_chat.application_id': models.UUIDField(),
|
||||
'application_chat_record.create_time': models.DateTimeField()})).filter(
|
||||
**{'application_chat.application_id': self.data.get('application_id'),
|
||||
'application_chat_record.create_time__gte': start_time,
|
||||
'application_chat_record.create_time__lte': end_time}
|
||||
)},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'get_token_usage.sql')))
|
||||
return get_token_usage
|
||||
|
||||
def get_top_questions_statistics(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
start_time = self.get_start_time()
|
||||
end_time = self.get_end_time()
|
||||
get_top_questions = native_search(
|
||||
{'default_sql': QuerySet(model=get_dynamics_model(
|
||||
{'application_chat.application_id': models.UUIDField(),
|
||||
'application_chat_record.create_time': models.DateTimeField()})).filter(
|
||||
**{'application_chat.application_id': self.data.get('application_id'),
|
||||
'application_chat_record.create_time__gte': start_time,
|
||||
'application_chat_record.create_time__lte': end_time}
|
||||
)},
|
||||
select_string=get_file_content(
|
||||
os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'top_questions.sql')))
|
||||
return get_top_questions
|
||||
|
|
@ -6,7 +6,7 @@
|
|||
@date:2025/6/9 13:42
|
||||
@desc:
|
||||
"""
|
||||
from datetime import datetime
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from django.core.cache import cache
|
||||
|
|
@ -20,6 +20,7 @@ from application.models import Application, ChatRecord, Chat, ApplicationVersion
|
|||
from application.serializers.application_chat import ChatCountSerializer
|
||||
from common.constants.cache_version import Cache_Version
|
||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||
from common.encoder.encoder import SystemEncoder
|
||||
from common.exception.app_exception import ChatException
|
||||
from knowledge.models import Document
|
||||
from models_provider.models import Model
|
||||
|
|
@ -226,10 +227,69 @@ class ChatInfo:
|
|||
chat_record.save()
|
||||
ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat()
|
||||
|
||||
def to_dict(self):
|
||||
|
||||
return {
|
||||
'chat_id': self.chat_id,
|
||||
'chat_user_id': self.chat_user_id,
|
||||
'chat_user_type': self.chat_user_type,
|
||||
'knowledge_id_list': self.knowledge_id_list,
|
||||
'exclude_document_id_list': self.exclude_document_id_list,
|
||||
'application_id': self.application_id,
|
||||
'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list][-20:],
|
||||
'debug': self.debug
|
||||
}
|
||||
|
||||
def chat_record_to_map(self, chat_record):
|
||||
return {'id': chat_record.id,
|
||||
'chat_id': chat_record.chat_id,
|
||||
'vote_status': chat_record.vote_status,
|
||||
'problem_text': chat_record.problem_text,
|
||||
'answer_text': chat_record.answer_text,
|
||||
'answer_text_list': chat_record.answer_text_list,
|
||||
'message_tokens': chat_record.message_tokens,
|
||||
'answer_tokens': chat_record.answer_tokens,
|
||||
'const': chat_record.const,
|
||||
'details': chat_record.details,
|
||||
'improve_paragraph_id_list': chat_record.improve_paragraph_id_list,
|
||||
'run_time': chat_record.run_time,
|
||||
'index': chat_record.index}
|
||||
|
||||
@staticmethod
|
||||
def map_to_chat_record(chat_record_dict):
|
||||
return ChatRecord(id=chat_record_dict.get('id'),
|
||||
chat_id=chat_record_dict.get('chat_id'),
|
||||
vote_status=chat_record_dict.get('vote_status'),
|
||||
problem_text=chat_record_dict.get('problem_text'),
|
||||
answer_text=chat_record_dict.get('answer_text'),
|
||||
answer_text_list=chat_record_dict.get('answer_text_list'),
|
||||
message_tokens=chat_record_dict.get('message_tokens'),
|
||||
answer_tokens=chat_record_dict.get('answer_tokens'),
|
||||
const=chat_record_dict.get('const'),
|
||||
details=chat_record_dict.get('details'),
|
||||
improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'),
|
||||
run_time=chat_record_dict.get('run_time'),
|
||||
index=chat_record_dict.get('index'), )
|
||||
|
||||
def set_cache(self):
|
||||
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(),
|
||||
cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(),
|
||||
version=Cache_Version.CHAT_INFO.get_version(),
|
||||
timeout=60 * 30)
|
||||
|
||||
@staticmethod
|
||||
def map_to_chat_info(chat_info_dict):
|
||||
c = ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'),
|
||||
chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'),
|
||||
chat_info_dict.get('exclude_document_id_list'),
|
||||
chat_info_dict.get('application_id'),
|
||||
debug=chat_info_dict.get('debug'))
|
||||
c.chat_record_list = [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')]
|
||||
return c
|
||||
|
||||
@staticmethod
|
||||
def get_cache(chat_id):
|
||||
return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version())
|
||||
chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id),
|
||||
version=Cache_Version.CHAT_INFO.get_version())
|
||||
if chat_info_dict:
|
||||
return ChatInfo.map_to_chat_info(chat_info_dict)
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -0,0 +1,12 @@
|
|||
|
||||
SELECT
|
||||
SUM(application_chat_record.message_tokens + application_chat_record.answer_tokens) as "token_usage",
|
||||
COALESCE(application_chat.asker->>'username', '游客') as "username"
|
||||
FROM
|
||||
application_chat_record application_chat_record
|
||||
LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id
|
||||
${default_sql}
|
||||
GROUP BY
|
||||
COALESCE(application_chat.asker->>'username', '游客')
|
||||
ORDER BY
|
||||
"token_usage" DESC
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
SELECT COUNT(application_chat_record."id") AS chat_record_count,
|
||||
COALESCE(application_chat.asker ->>'username', '游客') AS username
|
||||
FROM application_chat_record application_chat_record
|
||||
LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id
|
||||
${default_sql}
|
||||
GROUP BY
|
||||
COALESCE (application_chat.asker->>'username', '游客')
|
||||
ORDER BY
|
||||
chat_record_count DESC,
|
||||
username ASC
|
||||
|
||||
|
|
@ -13,6 +13,8 @@ urlpatterns = [
|
|||
path('workspace/<str:workspace_id>/application/<str:application_id>/publish', views.ApplicationAPI.Publish.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/application_key', views.ApplicationKey.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/application_stats', views.ApplicationStats.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/application_token_usage', views.ApplicationStats.TokenUsageStatistics.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/top_questions', views.ApplicationStats.TopQuestionsStatistics.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/application_key/<str:api_key_id>', views.ApplicationKey.Operate.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/export', views.ApplicationAPI.Export.as_view()),
|
||||
path('workspace/<str:workspace_id>/application/<str:application_id>/application_version', views.ApplicationVersionView.as_view()),
|
||||
|
|
|
|||
|
|
@ -125,8 +125,8 @@ class OpenView(APIView):
|
|||
responses=None,
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
|
||||
@has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
|
||||
ViewPermission([RoleConstants.USER.get_workspace_role()],
|
||||
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
|
||||
CompareConstants.AND),
|
||||
|
|
@ -167,8 +167,8 @@ class PromptGenerateView(APIView):
|
|||
responses=None,
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
@has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
|
||||
@has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
|
||||
ViewPermission([RoleConstants.USER.get_workspace_role()],
|
||||
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
|
||||
CompareConstants.AND),
|
||||
|
|
|
|||
|
|
@ -93,8 +93,8 @@ class ApplicationChatRecordOperateAPI(APIView):
|
|||
)
|
||||
@has_permissions(PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_permission_workspace_manage_role(),
|
||||
PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
|
||||
PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
|
||||
ViewPermission([RoleConstants.USER.get_workspace_role()],
|
||||
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
|
||||
CompareConstants.AND),
|
||||
|
|
|
|||
|
|
@ -46,3 +46,58 @@ class ApplicationStats(APIView):
|
|||
'end_time': request.query_params.get(
|
||||
'end_time')
|
||||
}).get_chat_record_aggregate_trend())
|
||||
|
||||
class TokenUsageStatistics(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
|
||||
# 应用的token使用统计 根据人的使用数排序
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Application token usage statistics'),
|
||||
summary=_('Application token usage statistics'),
|
||||
operation_id=_('Application token usage statistics'), # type: ignore
|
||||
parameters=ApplicationStatsAPI.get_parameters(),
|
||||
responses=ApplicationStatsAPI.get_response(),
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
@has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
|
||||
ViewPermission([RoleConstants.USER.get_workspace_role()],
|
||||
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
|
||||
CompareConstants.AND),
|
||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
|
||||
def get(self, request: Request, workspace_id: str, application_id: str):
|
||||
return result.success(
|
||||
ApplicationStatisticsSerializer(data={'application_id': application_id, 'workspace_id': workspace_id,
|
||||
'start_time': request.query_params.get(
|
||||
'start_time'),
|
||||
'end_time': request.query_params.get(
|
||||
'end_time')
|
||||
}).get_token_usage_statistics())
|
||||
|
||||
class TopQuestionsStatistics(APIView):
|
||||
authentication_classes = [TokenAuth]
|
||||
# 应用的top问题统计
|
||||
@extend_schema(
|
||||
methods=['GET'],
|
||||
description=_('Application top question statistics'),
|
||||
summary=_('Application top question statistics'),
|
||||
operation_id=_('Application top question statistics'), # type: ignore
|
||||
parameters=ApplicationStatsAPI.get_parameters(),
|
||||
responses=ApplicationStatsAPI.get_response(),
|
||||
tags=[_('Application')] # type: ignore
|
||||
)
|
||||
@has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
|
||||
PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
|
||||
ViewPermission([RoleConstants.USER.get_workspace_role()],
|
||||
[PermissionConstants.APPLICATION.get_workspace_application_permission()],
|
||||
CompareConstants.AND),
|
||||
RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
|
||||
def get(self, request: Request, workspace_id: str, application_id: str):
|
||||
return result.success(
|
||||
ApplicationStatisticsSerializer(data={'application_id': application_id, 'workspace_id': workspace_id,
|
||||
'start_time': request.query_params.get(
|
||||
'start_time'),
|
||||
'end_time': request.query_params.get(
|
||||
'end_time')
|
||||
}).get_top_questions_statistics())
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2023/9/4 11:16
|
||||
@desc: 认证类
|
||||
"""
|
||||
import traceback
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
|
|
@ -18,6 +17,7 @@ from rest_framework.authentication import TokenAuthentication
|
|||
|
||||
from common.exception.app_exception import AppAuthenticationFailed, AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed, \
|
||||
AppApiException
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
token_cache = cache.caches['default']
|
||||
|
||||
|
|
@ -88,7 +88,7 @@ class TokenAuth(TokenAuthentication):
|
|||
return handle.handle(request, token, token_details.get_token_details)
|
||||
raise AppAuthenticationFailed(1002, _('Authentication information is incorrect! illegal user'))
|
||||
except Exception as e:
|
||||
traceback.print_stack()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed) or isinstance(e,
|
||||
AppApiException):
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class ChatAuthentication:
|
|||
value = json.dumps(self.to_dict())
|
||||
authentication = encrypt(value)
|
||||
cache_key = hashlib.sha256(authentication.encode()).hexdigest()
|
||||
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
|
||||
authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
|
||||
return authentication
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -7,18 +7,24 @@
|
|||
@desc:
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
class MKTokenizer:
|
||||
def __init__(self, tokenizer):
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def encode(self, text):
|
||||
return self.tokenizer.encode(text).ids
|
||||
|
||||
|
||||
class TokenizerManage:
|
||||
tokenizer = None
|
||||
|
||||
@staticmethod
|
||||
def get_tokenizer():
|
||||
from transformers import BertTokenizer
|
||||
if TokenizerManage.tokenizer is None:
|
||||
TokenizerManage.tokenizer = BertTokenizer.from_pretrained(
|
||||
'bert-base-cased',
|
||||
cache_dir="/opt/maxkb-app/model/tokenizer",
|
||||
local_files_only=True,
|
||||
resume_download=False,
|
||||
force_download=False)
|
||||
return TokenizerManage.tokenizer
|
||||
from tokenizers import Tokenizer
|
||||
# 创建Tokenizer
|
||||
model_path = os.path.join("/opt/maxkb-app", "model", "tokenizer", "models--bert-base-cased")
|
||||
with open(f"{model_path}/refs/main", encoding="utf-8") as f: snapshot = f.read()
|
||||
TokenizerManage.tokenizer = Tokenizer.from_file(f"{model_path}/snapshots/{snapshot}/tokenizer.json")
|
||||
return MKTokenizer(TokenizerManage.tokenizer)
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ class Cache_Version(Enum):
|
|||
# 对话
|
||||
CHAT = "CHAT", lambda key: key
|
||||
|
||||
CHAT_INFO = "CHAT_INFO", lambda key: key
|
||||
|
||||
CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
|
||||
|
||||
# 应用API KEY
|
||||
|
|
|
|||
|
|
@ -77,5 +77,5 @@ class HTMLSplitHandle(BaseSplitHandle):
|
|||
content = buffer.decode(encoding)
|
||||
return html2text(content)
|
||||
except BaseException as e:
|
||||
traceback.print_exception(e)
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
return f'{e}'
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ class Services(TextChoices):
|
|||
gunicorn = 'gunicorn', 'gunicorn'
|
||||
celery_default = 'celery_default', 'celery_default'
|
||||
local_model = 'local_model', 'local_model'
|
||||
scheduler = 'scheduler', 'scheduler'
|
||||
web = 'web', 'web'
|
||||
celery = 'celery', 'celery'
|
||||
celery_model = 'celery_model', 'celery_model'
|
||||
|
|
@ -25,7 +24,6 @@ class Services(TextChoices):
|
|||
cls.gunicorn.value: services.GunicornService,
|
||||
cls.celery_default: services.CeleryDefaultService,
|
||||
cls.local_model: services.GunicornLocalModelService,
|
||||
cls.scheduler: services.SchedulerService,
|
||||
}
|
||||
return services_map.get(name)
|
||||
|
||||
|
|
@ -41,13 +39,10 @@ class Services(TextChoices):
|
|||
def task_services(cls):
|
||||
return cls.celery_services()
|
||||
|
||||
@classmethod
|
||||
def scheduler_services(cls):
|
||||
return [cls.scheduler]
|
||||
|
||||
@classmethod
|
||||
def all_services(cls):
|
||||
return cls.web_services() + cls.task_services() + cls.scheduler_services()
|
||||
return cls.web_services() + cls.task_services()
|
||||
|
||||
@classmethod
|
||||
def export_services_values(cls):
|
||||
|
|
@ -101,7 +96,7 @@ class BaseActionCommand(BaseCommand):
|
|||
)
|
||||
parser.add_argument('-d', '--daemon', nargs="?", const=True)
|
||||
parser.add_argument('-w', '--worker', type=int, nargs="?",
|
||||
default=2 if os.cpu_count() > 6 else math.floor(os.cpu_count() / 2))
|
||||
default=3 if os.cpu_count() > 6 else max(1, math.floor(os.cpu_count() / 2)))
|
||||
parser.add_argument('-f', '--force', nargs="?", const=True)
|
||||
|
||||
def initial_util(self, *args, **options):
|
||||
|
|
|
|||
|
|
@ -18,14 +18,17 @@ class GunicornService(BaseService):
|
|||
|
||||
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
|
||||
bind = f'{HTTP_HOST}:{HTTP_PORT}'
|
||||
max_requests = 10240 if int(self.worker) > 1 else 0
|
||||
cmd = [
|
||||
'gunicorn', 'maxkb.wsgi:application',
|
||||
'-b', bind,
|
||||
'-k', 'gthread',
|
||||
'--threads', '200',
|
||||
'-w', str(self.worker),
|
||||
'--max-requests', '10240',
|
||||
'--max-requests', str(max_requests),
|
||||
'--max-requests-jitter', '2048',
|
||||
'--timeout', '0',
|
||||
'--graceful-timeout', '0',
|
||||
'--access-logformat', log_format,
|
||||
'--access-logfile', '/dev/null',
|
||||
'--error-logfile', '-'
|
||||
|
|
|
|||
|
|
@ -27,14 +27,17 @@ class GunicornLocalModelService(BaseService):
|
|||
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1)
|
||||
max_requests = 10240 if int(worker) > 1 else 0
|
||||
cmd = [
|
||||
'gunicorn', 'maxkb.wsgi:application',
|
||||
'-b', bind,
|
||||
'-k', 'gthread',
|
||||
'--threads', '200',
|
||||
'-w', str(worker),
|
||||
'--max-requests', '10240',
|
||||
'--max-requests', str(max_requests),
|
||||
'--max-requests-jitter', '2048',
|
||||
'--timeout', '0',
|
||||
'--graceful-timeout', '0',
|
||||
'--access-logformat', log_format,
|
||||
'--access-logfile', '/dev/null',
|
||||
'--error-logfile', '-'
|
||||
|
|
|
|||
|
|
@ -18,14 +18,17 @@ class SchedulerService(BaseService):
|
|||
|
||||
log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
|
||||
bind = f'127.0.0.1:6060'
|
||||
max_requests = 10240 if int(self.worker) > 1 else 0
|
||||
cmd = [
|
||||
'gunicorn', 'maxkb.wsgi:application',
|
||||
'-b', bind,
|
||||
'-k', 'gthread',
|
||||
'--threads', '200',
|
||||
'-w', str(self.worker),
|
||||
'--max-requests', '10240',
|
||||
'--max-requests', str(max_requests),
|
||||
'--max-requests-jitter', '2048',
|
||||
'--timeout', '0',
|
||||
'--graceful-timeout', '0',
|
||||
'--access-logformat', log_format,
|
||||
'--access-logfile', '/dev/null',
|
||||
'--error-logfile', '-'
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from django.db import models
|
|||
|
||||
|
||||
class AppModelMixin(models.Model):
|
||||
objects = models.Manager()
|
||||
|
||||
create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, db_index=True)
|
||||
update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, db_index=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
"""
|
||||
import base64
|
||||
import threading
|
||||
from functools import lru_cache
|
||||
|
||||
from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
|
||||
from Crypto.PublicKey import RSA
|
||||
|
|
@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
|
|||
"""
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(public_key))
|
||||
cipher = _get_encrypt_cipher(public_key)
|
||||
encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
|
||||
return base64.b64encode(encrypt_msg).decode()
|
||||
|
||||
|
|
@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
|
|||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
cipher = _get_cipher(pri_key)
|
||||
decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
|
||||
return decrypt_data.decode("utf-8")
|
||||
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_encrypt_cipher(public_key: str):
|
||||
"""缓存加密 cipher 对象"""
|
||||
return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
|
||||
|
||||
|
||||
def rsa_long_encrypt(message, public_key: str | None = None, length=200):
|
||||
"""
|
||||
超长文本加密
|
||||
|
||||
:param message: 需要加密的字符串
|
||||
:param public_key 公钥
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:param length: 1024bit的证书用100, 2048bit的证书用 200
|
||||
:return: 加密后的数据
|
||||
"""
|
||||
# 读取公钥
|
||||
if public_key is None:
|
||||
public_key = get_key_pair().get('key')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
|
||||
passphrase=secret_code))
|
||||
# 处理:Plaintext is too long. 分段加密
|
||||
|
||||
cipher = _get_encrypt_cipher(public_key)
|
||||
|
||||
if len(message) <= length:
|
||||
# 对编码的数据进行加密,并通过base64进行编码
|
||||
result = base64.b64encode(cipher.encrypt(message.encode('utf-8')))
|
||||
else:
|
||||
rsa_text = []
|
||||
# 对编码后的数据进行切片,原因:加密长度不能过长
|
||||
for i in range(0, len(message), length):
|
||||
cont = message[i:i + length]
|
||||
# 对切片后的数据进行加密,并新增到text后面
|
||||
rsa_text.append(cipher.encrypt(cont.encode('utf-8')))
|
||||
# 加密完进行拼接
|
||||
cipher_text = b''.join(rsa_text)
|
||||
# base64进行编码
|
||||
result = base64.b64encode(cipher_text)
|
||||
|
||||
return result.decode()
|
||||
|
||||
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_cipher(pri_key: str):
|
||||
"""缓存 cipher 对象,避免重复创建"""
|
||||
return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
|
||||
|
||||
def rsa_long_decrypt(message, pri_key: str | None = None, length=256):
|
||||
"""
|
||||
超长文本解密,默认不加密
|
||||
超长文本解密,优化内存使用
|
||||
:param message: 需要解密的数据
|
||||
:param pri_key: 秘钥
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:param length : 1024bit的证书用128,2048bit证书用256位
|
||||
:return: 解密后的数据
|
||||
"""
|
||||
if pri_key is None:
|
||||
pri_key = get_key_pair().get('value')
|
||||
cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
|
||||
|
||||
cipher = _get_cipher(pri_key)
|
||||
base64_de = base64.b64decode(message)
|
||||
res = []
|
||||
|
||||
# 使用 bytearray 减少内存分配
|
||||
result = bytearray()
|
||||
for i in range(0, len(base64_de), length):
|
||||
res.append(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
return b"".join(res).decode()
|
||||
result.extend(cipher.decrypt(base64_de[i:i + length], 0))
|
||||
|
||||
return result.decode()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,51 +1,77 @@
|
|||
# coding=utf-8
|
||||
import ast
|
||||
import base64
|
||||
import gzip
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
import socket
|
||||
import uuid_utils.compat as uuid
|
||||
from common.utils.logger import maxkb_logger
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from maxkb.const import BASE_DIR, CONFIG
|
||||
from maxkb.const import PROJECT_DIR
|
||||
from textwrap import dedent
|
||||
|
||||
python_directory = sys.executable
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
|
||||
def __init__(self, sandbox=False):
|
||||
self.sandbox = sandbox
|
||||
if sandbox:
|
||||
self.sandbox_path = '/opt/maxkb-app/sandbox'
|
||||
self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox')
|
||||
self.user = 'sandbox'
|
||||
else:
|
||||
self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox')
|
||||
self.user = None
|
||||
self._createdir()
|
||||
if self.sandbox:
|
||||
os.system(f"chown -R {self.user}:root {self.sandbox_path}")
|
||||
self.banned_keywords = CONFIG.get("SANDBOX_PYTHON_BANNED_KEYWORDS", 'nothing_is_banned').split(',');
|
||||
banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip()
|
||||
self.sandbox_so_path = f'{self.sandbox_path}/sandbox.so'
|
||||
try:
|
||||
if banned_hosts:
|
||||
hostname = socket.gethostname()
|
||||
local_ip = socket.gethostbyname(hostname)
|
||||
banned_hosts = f"{banned_hosts},{hostname},{local_ip}"
|
||||
except Exception:
|
||||
pass
|
||||
self.banned_hosts = banned_hosts
|
||||
self._init_dir()
|
||||
except Exception as e:
|
||||
# 本机忽略异常,容器内不忽略
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if self.sandbox:
|
||||
raise e
|
||||
|
||||
def _createdir(self):
|
||||
old_mask = os.umask(0o077)
|
||||
def _init_dir(self):
|
||||
try:
|
||||
os.makedirs(self.sandbox_path, 0o700, exist_ok=True)
|
||||
os.makedirs(os.path.join(self.sandbox_path, 'execute'), 0o700, exist_ok=True)
|
||||
os.makedirs(os.path.join(self.sandbox_path, 'result'), 0o700, exist_ok=True)
|
||||
finally:
|
||||
os.umask(old_mask)
|
||||
# 只初始化一次
|
||||
fd = os.open(os.path.join(PROJECT_DIR, 'tmp', 'tool_executor_init_dir.lock'),
|
||||
os.O_CREAT | os.O_EXCL | os.O_WRONLY)
|
||||
os.close(fd)
|
||||
except FileExistsError:
|
||||
# 文件已存在 → 已初始化过
|
||||
return
|
||||
maxkb_logger.debug("init dir")
|
||||
if self.sandbox:
|
||||
try:
|
||||
os.system("chmod -R g-rwx /dev/shm /dev/mqueue")
|
||||
os.system("chmod o-rwx /run/postgresql")
|
||||
except Exception as e:
|
||||
maxkb_logger.warning(f'Exception: {e}', exc_info=True)
|
||||
pass
|
||||
if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1":
|
||||
tmp_dir_path = os.path.join(self.sandbox_path, 'tmp')
|
||||
os.makedirs(tmp_dir_path, 0o700, exist_ok=True)
|
||||
os.system(f"chown -R {self.user}:root {tmp_dir_path}")
|
||||
if os.path.exists(self.sandbox_so_path):
|
||||
os.chmod(self.sandbox_so_path, 0o440)
|
||||
# 初始化host黑名单
|
||||
banned_hosts_file_path = f'{self.sandbox_path}/.SANDBOX_BANNED_HOSTS'
|
||||
if os.path.exists(banned_hosts_file_path):
|
||||
os.remove(banned_hosts_file_path)
|
||||
banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip()
|
||||
if banned_hosts:
|
||||
hostname = socket.gethostname()
|
||||
local_ip = socket.gethostbyname(hostname)
|
||||
banned_hosts = f"{banned_hosts},{hostname},{local_ip}"
|
||||
with open(banned_hosts_file_path, "w") as f:
|
||||
f.write(banned_hosts)
|
||||
os.chmod(banned_hosts_file_path, 0o440)
|
||||
|
||||
def exec_code(self, code_str, keywords, function_name=None):
|
||||
self.validate_banned_keywords(code_str)
|
||||
|
|
@ -57,9 +83,7 @@ class ToolExecutor:
|
|||
python_paths = CONFIG.get_sandbox_python_package_paths().split(',')
|
||||
_exec_code = f"""
|
||||
try:
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import sys, json, base64, builtins
|
||||
path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps']
|
||||
sys.path = [p for p in sys.path if p not in path_to_exclude]
|
||||
sys.path += {python_paths}
|
||||
|
|
@ -71,21 +95,21 @@ try:
|
|||
for local in locals_v:
|
||||
globals_v[local] = locals_v[local]
|
||||
exec_result=f(**keywords)
|
||||
with open({result_path!a}, 'w') as file:
|
||||
file.write(json.dumps({success}, default=str))
|
||||
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({success}, default=str).encode()).decode())
|
||||
except Exception as e:
|
||||
with open({result_path!a}, 'w') as file:
|
||||
file.write(json.dumps({err}))
|
||||
builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({err}, default=str).encode()).decode())
|
||||
"""
|
||||
if self.sandbox:
|
||||
subprocess_result = self._exec_sandbox(_exec_code, _id)
|
||||
subprocess_result = self._exec_sandbox(_exec_code)
|
||||
else:
|
||||
subprocess_result = self._exec(_exec_code)
|
||||
if subprocess_result.returncode == 1:
|
||||
raise Exception(subprocess_result.stderr)
|
||||
with open(result_path, 'r') as file:
|
||||
result = json.loads(file.read())
|
||||
os.remove(result_path)
|
||||
lines = subprocess_result.stdout.splitlines()
|
||||
result_line = [line for line in lines if line.startswith(_id)]
|
||||
if not result_line:
|
||||
raise Exception("No result found.")
|
||||
result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode())
|
||||
if result.get('code') == 200:
|
||||
return result.get('data')
|
||||
raise Exception(result.get('msg'))
|
||||
|
|
@ -173,52 +197,45 @@ exec({dedent(code)!a})
|
|||
"""
|
||||
|
||||
def get_tool_mcp_config(self, code, params):
|
||||
code = self.generate_mcp_server_code(code, params)
|
||||
|
||||
_id = uuid.uuid7()
|
||||
code_path = f'{self.sandbox_path}/execute/{_id}.py'
|
||||
with open(code_path, 'w') as f:
|
||||
f.write(code)
|
||||
_code = self.generate_mcp_server_code(code, params)
|
||||
maxkb_logger.debug(f"Python code of mcp tool: {_code}")
|
||||
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
|
||||
if self.sandbox:
|
||||
os.system(f"chown {self.user}:root {code_path}")
|
||||
|
||||
tool_config = {
|
||||
'command': 'su',
|
||||
'args': [
|
||||
'-s', sys.executable,
|
||||
'-c', f"exec(open('{code_path}', 'r').read())",
|
||||
'-c',
|
||||
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
|
||||
self.user,
|
||||
],
|
||||
'cwd': self.sandbox_path,
|
||||
'env': {
|
||||
'LD_PRELOAD': '/opt/maxkb-app/sandbox/sandbox.so',
|
||||
'SANDBOX_BANNED_HOSTS': self.banned_hosts,
|
||||
'LD_PRELOAD': self.sandbox_so_path,
|
||||
},
|
||||
'transport': 'stdio',
|
||||
}
|
||||
else:
|
||||
tool_config = {
|
||||
'command': sys.executable,
|
||||
'args': [code_path],
|
||||
'args': f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
|
||||
'transport': 'stdio',
|
||||
}
|
||||
return _id, tool_config
|
||||
return tool_config
|
||||
|
||||
def _exec_sandbox(self, _code, _id):
|
||||
exec_python_file = f'{self.sandbox_path}/execute/{_id}.py'
|
||||
with open(exec_python_file, 'w') as file:
|
||||
file.write(_code)
|
||||
os.system(f"chown {self.user}:root {exec_python_file}")
|
||||
def _exec_sandbox(self, _code):
|
||||
kwargs = {'cwd': BASE_DIR}
|
||||
kwargs['env'] = {
|
||||
'LD_PRELOAD': '/opt/maxkb-app/sandbox/sandbox.so',
|
||||
'SANDBOX_BANNED_HOSTS': self.banned_hosts,
|
||||
'LD_PRELOAD': self.sandbox_so_path,
|
||||
}
|
||||
maxkb_logger.debug(f"Sandbox execute code: {_code}")
|
||||
compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode()
|
||||
subprocess_result = subprocess.run(
|
||||
['su', '-s', python_directory, '-c', "exec(open('" + exec_python_file + "').read())", self.user],
|
||||
['su', '-s', python_directory, '-c',
|
||||
f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())',
|
||||
self.user],
|
||||
text=True,
|
||||
capture_output=True, **kwargs)
|
||||
os.remove(exec_python_file)
|
||||
return subprocess_result
|
||||
|
||||
def validate_banned_keywords(self, code_str):
|
||||
|
|
|
|||
|
|
@ -816,7 +816,7 @@ class DocumentSerializers(serializers.Serializer):
|
|||
|
||||
@post(post_function=post_embedding)
|
||||
@transaction.atomic
|
||||
def save(self, instance: Dict, with_valid=False, **kwargs):
|
||||
def save(self, instance: Dict, with_valid=True, **kwargs):
|
||||
if with_valid:
|
||||
DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True)
|
||||
self.is_valid(raise_exception=True)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
@desc:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import List
|
||||
|
||||
|
|
|
|||
|
|
@ -542,9 +542,9 @@ class DocumentView(APIView):
|
|||
|
||||
@extend_schema(
|
||||
methods=['PUT'],
|
||||
summary=_('Batch generate related documents'),
|
||||
description=_('Batch generate related documents'),
|
||||
operation_id=_('Batch generate related documents'), # type: ignore
|
||||
summary=_('Batch generate related problems'),
|
||||
description=_('Batch generate related problems'),
|
||||
operation_id=_('Batch generate related problems'), # type: ignore
|
||||
request=BatchGenerateRelatedAPI.get_request(),
|
||||
parameters=BatchGenerateRelatedAPI.get_parameters(),
|
||||
responses=BatchGenerateRelatedAPI.get_response(),
|
||||
|
|
@ -560,7 +560,7 @@ class DocumentView(APIView):
|
|||
[PermissionConstants.KNOWLEDGE.get_workspace_knowledge_permission()], CompareConstants.AND),
|
||||
)
|
||||
@log(
|
||||
menu='document', operate="Batch generate related documents",
|
||||
menu='document', operate="Batch generate related problems",
|
||||
get_operation_object=lambda r, keywords: get_knowledge_document_operation_object(
|
||||
get_knowledge_operation_object(keywords.get('knowledge_id')),
|
||||
get_document_operation_object_batch(r.data.get('document_id_list'))
|
||||
|
|
|
|||
|
|
@ -74,7 +74,6 @@ class ModelManage:
|
|||
|
||||
|
||||
def get_local_model(model, **kwargs):
|
||||
# system_setting = QuerySet(SystemSetting).filter(type=1).first()
|
||||
return LocalModelProvider().get_model(model.model_type, model.model_name,
|
||||
json.loads(
|
||||
rsa_long_decrypt(model.credential)),
|
||||
|
|
@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer):
|
|||
query = serializers.CharField(required=True, label=_('query'))
|
||||
|
||||
|
||||
class ValidateModelSerializers(serializers.Serializer):
|
||||
model_name = serializers.CharField(required=True, label=_('model_name'))
|
||||
|
||||
model_type = serializers.CharField(required=True, label=_('model_type'))
|
||||
|
||||
model_credential = serializers.DictField(required=True, label="credential")
|
||||
|
||||
def validate_model(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'),
|
||||
self.data.get('model_credential'), model_params={},
|
||||
raise_exception=True)
|
||||
|
||||
|
||||
class ModelApplySerializers(serializers.Serializer):
|
||||
model_id = serializers.UUIDField(required=True, label=_('model id'))
|
||||
|
||||
|
|
@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer):
|
|||
return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents(
|
||||
[Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in
|
||||
instance.get('documents')], instance.get('query'))]
|
||||
|
||||
def unload(self, with_valid=True):
|
||||
if with_valid:
|
||||
self.is_valid(raise_exception=True)
|
||||
ModelManage.delete_key(self.data.get('model_id'))
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from . import views
|
|||
app_name = "local_model"
|
||||
# @formatter:off
|
||||
urlpatterns = [
|
||||
path('model/validate', views.LocalModelApply.Validate.as_view()),
|
||||
path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()),
|
||||
path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()),
|
||||
path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()),
|
||||
path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from urllib.request import Request
|
|||
from rest_framework.views import APIView
|
||||
|
||||
from common.result import result
|
||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers
|
||||
from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers
|
||||
|
||||
|
||||
class LocalModelApply(APIView):
|
||||
|
|
@ -32,3 +32,12 @@ class LocalModelApply(APIView):
|
|||
def post(self, request: Request, model_id):
|
||||
return result.success(
|
||||
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||
|
||||
class Unload(APIView):
|
||||
def post(self, request: Request, model_id):
|
||||
return result.success(
|
||||
ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data))
|
||||
|
||||
class Validate(APIView):
|
||||
def post(self, request: Request):
|
||||
return result.success(ValidateModelSerializers(data=request.data).validate_model())
|
||||
|
|
|
|||
|
|
@ -2945,7 +2945,7 @@ msgstr ""
|
|||
|
||||
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
|
||||
#: apps/knowledge/views/document.py:468
|
||||
msgid "Batch generate related documents"
|
||||
msgid "Batch generate related problems"
|
||||
msgstr ""
|
||||
|
||||
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
|
||||
|
|
@ -5126,7 +5126,7 @@ msgstr ""
|
|||
#: apps/resource_manage/views/document.py:261
|
||||
#: apps/resource_manage/views/document.py:262
|
||||
#: apps/resource_manage/views/document.py:263
|
||||
msgid "Batch generate related system document"
|
||||
msgid "Batch generate related system problems"
|
||||
msgstr ""
|
||||
|
||||
#: apps/resource_manage/views/document.py:290
|
||||
|
|
@ -5946,7 +5946,7 @@ msgstr ""
|
|||
#: apps/shared/views/shared_document.py:262
|
||||
#: apps/shared/views/shared_document.py:263
|
||||
#: apps/shared/views/shared_document.py:264
|
||||
msgid "Batch generate related shared documents"
|
||||
msgid "Batch generate related shared problems"
|
||||
msgstr ""
|
||||
|
||||
#: apps/shared/views/shared_document.py:291
|
||||
|
|
@ -6671,7 +6671,7 @@ msgid "Phone"
|
|||
msgstr ""
|
||||
|
||||
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
|
||||
msgid "Username must be 4-20 characters long"
|
||||
msgid "Username must be 4-64 characters long"
|
||||
msgstr ""
|
||||
|
||||
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
|
||||
|
|
@ -8766,4 +8766,10 @@ msgid "Non-existent id"
|
|||
msgstr ""
|
||||
|
||||
msgid "No permission for the target folder"
|
||||
msgstr ""
|
||||
|
||||
msgid "Application token usage statistics"
|
||||
msgstr ""
|
||||
|
||||
msgid "Application top question statistics"
|
||||
msgstr ""
|
||||
|
|
@ -2952,8 +2952,8 @@ msgstr "批量刷新文档向量库"
|
|||
|
||||
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
|
||||
#: apps/knowledge/views/document.py:468
|
||||
msgid "Batch generate related documents"
|
||||
msgstr "批量生成相关文档"
|
||||
msgid "Batch generate related problems"
|
||||
msgstr "批量生成相关问题"
|
||||
|
||||
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
|
||||
#: apps/knowledge/views/document.py:498
|
||||
|
|
@ -5251,8 +5251,8 @@ msgstr "批量刷新文档向量库"
|
|||
#: apps/resource_manage/views/document.py:261
|
||||
#: apps/resource_manage/views/document.py:262
|
||||
#: apps/resource_manage/views/document.py:263
|
||||
msgid "Batch generate related system document"
|
||||
msgstr "批量生成相关文档"
|
||||
msgid "Batch generate related system problems"
|
||||
msgstr "批量生成相关问题"
|
||||
|
||||
#: apps/resource_manage/views/document.py:290
|
||||
#: apps/resource_manage/views/document.py:291
|
||||
|
|
@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文档向量库"
|
|||
#: apps/shared/views/shared_document.py:262
|
||||
#: apps/shared/views/shared_document.py:263
|
||||
#: apps/shared/views/shared_document.py:264
|
||||
msgid "Batch generate related shared documents"
|
||||
msgstr "批量生成相关共享文档"
|
||||
msgid "Batch generate related shared problems"
|
||||
msgstr "批量生成相关问题"
|
||||
|
||||
#: apps/shared/views/shared_document.py:291
|
||||
#: apps/shared/views/shared_document.py:292
|
||||
|
|
@ -6796,8 +6796,8 @@ msgid "Phone"
|
|||
msgstr "手机"
|
||||
|
||||
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
|
||||
msgid "Username must be 4-20 characters long"
|
||||
msgstr "用户名必须为4-20个字符"
|
||||
msgid "Username must be 4-64 characters long"
|
||||
msgstr "用户名必须为4-64个字符"
|
||||
|
||||
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
|
||||
#: apps/xpack/serializers/chat_user.py:81
|
||||
|
|
@ -8894,3 +8894,8 @@ msgstr "不存在的ID"
|
|||
msgid "No permission for the target folder"
|
||||
msgstr "没有目标文件夹的权限"
|
||||
|
||||
msgid "Application token usage statistics"
|
||||
msgstr "应用令牌使用统计"
|
||||
|
||||
msgid "Application top question statistics"
|
||||
msgstr "应用提问次数统计"
|
||||
|
|
|
|||
|
|
@ -2952,8 +2952,8 @@ msgstr "批量刷新文檔向量庫"
|
|||
|
||||
#: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467
|
||||
#: apps/knowledge/views/document.py:468
|
||||
msgid "Batch generate related documents"
|
||||
msgstr "批量生成相關文檔"
|
||||
msgid "Batch generate related problems"
|
||||
msgstr "批量生成相關問題"
|
||||
|
||||
#: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497
|
||||
#: apps/knowledge/views/document.py:498
|
||||
|
|
@ -5251,8 +5251,8 @@ msgstr "批量刷新文檔向量庫"
|
|||
#: apps/resource_manage/views/document.py:261
|
||||
#: apps/resource_manage/views/document.py:262
|
||||
#: apps/resource_manage/views/document.py:263
|
||||
msgid "Batch generate related system document"
|
||||
msgstr "批量生成相關文檔"
|
||||
msgid "Batch generate related system problems"
|
||||
msgstr "批量生成相關問題"
|
||||
|
||||
#: apps/resource_manage/views/document.py:290
|
||||
#: apps/resource_manage/views/document.py:291
|
||||
|
|
@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文檔向量庫"
|
|||
#: apps/shared/views/shared_document.py:262
|
||||
#: apps/shared/views/shared_document.py:263
|
||||
#: apps/shared/views/shared_document.py:264
|
||||
msgid "Batch generate related shared documents"
|
||||
msgstr "批量生成相關共享文檔"
|
||||
msgid "Batch generate related shared problems"
|
||||
msgstr "批量生成相關問題"
|
||||
|
||||
#: apps/shared/views/shared_document.py:291
|
||||
#: apps/shared/views/shared_document.py:292
|
||||
|
|
@ -6796,8 +6796,8 @@ msgid "Phone"
|
|||
msgstr "手機"
|
||||
|
||||
#: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68
|
||||
msgid "Username must be 4-20 characters long"
|
||||
msgstr "用戶名必須為4-20個字符"
|
||||
msgid "Username must be 4-64 characters long"
|
||||
msgstr "用戶名必須為4-64個字符"
|
||||
|
||||
#: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298
|
||||
#: apps/xpack/serializers/chat_user.py:81
|
||||
|
|
@ -8894,3 +8894,8 @@ msgstr "不存在的ID"
|
|||
msgid "No permission for the target folder"
|
||||
msgstr "沒有目標資料夾的權限"
|
||||
|
||||
msgid "Application token usage statistics"
|
||||
msgstr "應用令牌使用統計"
|
||||
|
||||
msgid "Application top question statistics"
|
||||
msgstr "應用提問次數統計"
|
||||
|
|
@ -6,16 +6,36 @@
|
|||
@date:2025/11/5 15:14
|
||||
@desc:
|
||||
"""
|
||||
import builtins
|
||||
import os
|
||||
import sys
|
||||
|
||||
from django.core.wsgi import get_wsgi_application
|
||||
|
||||
|
||||
class TorchBlocker:
|
||||
def __init__(self):
|
||||
self.original_import = builtins.__import__
|
||||
|
||||
def __call__(self, name, *args, **kwargs):
|
||||
if len([True for i in
|
||||
['torch']
|
||||
if
|
||||
i in name.lower()]) > 0:
|
||||
print(f"Disable package is being imported: 【{name}】", file=sys.stderr)
|
||||
pass
|
||||
else:
|
||||
return self.original_import(name, *args, **kwargs)
|
||||
|
||||
|
||||
# 安装导入拦截器
|
||||
builtins.__import__ = TorchBlocker()
|
||||
|
||||
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings')
|
||||
|
||||
os.environ['TIKTOKEN_CACHE_DIR'] = '/opt/maxkb-app/model/tokenizer/openai-tiktoken-cl100k-base'
|
||||
application = get_wsgi_application()
|
||||
|
||||
|
||||
|
||||
def post_handler():
|
||||
from common.database_model_manage.database_model_manage import DatabaseModelManage
|
||||
from common import event
|
||||
|
|
@ -24,14 +44,5 @@ def post_handler():
|
|||
DatabaseModelManage.init()
|
||||
|
||||
|
||||
def post_scheduler_handler():
|
||||
from common import job
|
||||
|
||||
job.run()
|
||||
|
||||
# 启动后处理函数
|
||||
post_handler()
|
||||
|
||||
# 仅在scheduler中启动定时任务,dev local_model celery 不需要
|
||||
if os.environ.get('ENABLE_SCHEDULER') == '1':
|
||||
post_scheduler_handler()
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from common import forms
|
||||
|
|
@ -7,7 +6,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from django.utils.translation import gettext as _
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential):
|
||||
api_url = forms.TextInputField(_('API URL'), required=True)
|
||||
|
|
@ -41,7 +40,7 @@ class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential):
|
|||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/10/16 17:01
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -16,6 +15,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class BaiLianEmbeddingModelParams(BaseForm):
|
||||
dimensions = forms.SingleSelect(
|
||||
|
|
@ -69,7 +69,7 @@ class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_("Hello"))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,11 +6,9 @@
|
|||
@date:2024/7/11 18:41
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from common import forms
|
||||
from common.exception.app_exception import AppApiException
|
||||
|
|
@ -69,7 +67,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth(model_credential.get('api_key'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
|
||||
from common.forms.switch_field import SwitchField
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
"""
|
||||
|
|
@ -85,7 +84,7 @@ class ImageToVideoModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -9,7 +8,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class BaiLianLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(
|
||||
|
|
@ -76,7 +75,7 @@ class BaiLianLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
else:
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from common import forms
|
||||
|
|
@ -7,6 +6,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, PasswordInputField, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from django.utils.translation import gettext as _
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AliyunBaiLianOmiSTTModelParams(BaseForm):
|
||||
CueWord = forms.TextInputField(
|
||||
|
|
@ -49,7 +49,7 @@ class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential):
|
|||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -10,7 +9,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, PasswordInputField
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
|
||||
"""
|
||||
|
|
@ -60,7 +59,7 @@ class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential):
|
|||
model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -9,7 +8,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, PasswordInputField, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AliyunBaiLianSTTModelParams(BaseForm):
|
||||
sample_rate = forms.SliderField(
|
||||
|
|
@ -68,7 +67,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential,**model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
"""
|
||||
|
|
@ -110,7 +109,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AliyunBaiLianTTSModelGeneralParams(BaseForm):
|
||||
"""
|
||||
|
|
@ -103,7 +102,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
# coding=utf-8
|
||||
|
||||
import traceback
|
||||
from typing import Dict, Any
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel
|
||||
from common.forms.switch_field import SwitchField
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class QwenModelParams(BaseForm):
|
||||
"""
|
||||
|
|
@ -87,7 +86,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -54,7 +54,6 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText):
|
|||
'callback': None,
|
||||
**self.params
|
||||
}
|
||||
print(recognition_params)
|
||||
recognition = Recognition(**recognition_params)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -54,7 +53,7 @@ class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
|
|||
for chunk in res:
|
||||
maxkb_logger.info(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AnthropicLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -56,7 +55,7 @@ class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -8,7 +7,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -35,7 +34,7 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
except AppApiException:
|
||||
raise
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
_('Verification failed, please check whether the parameters are correct: {error}').format(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import ValidCode, BaseModelCredential
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class BedrockLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -54,7 +53,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
except AppApiException:
|
||||
raise
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if raise_exception:
|
||||
raise AppApiException(ValidCode.valid_error.value,
|
||||
gettext(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 17:08
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -15,7 +14,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -36,7 +35,7 @@ class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -57,7 +56,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||
for chunk in res:
|
||||
maxkb_logger.info(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 17:08
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AzureLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -68,7 +67,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
||||
api_version = forms.TextInputField("API Version", required=True)
|
||||
|
|
@ -32,7 +31,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AzureOpenAITTIModelParams(BaseForm):
|
||||
size = forms.SingleSelect(
|
||||
|
|
@ -67,7 +66,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class AzureOpenAITTSModelGeneralParams(BaseForm):
|
||||
# alloy, echo, fable, onyx, nova, shimmer
|
||||
|
|
@ -50,7 +49,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@
|
|||
@desc:
|
||||
"""
|
||||
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional, Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_openai import AzureChatOpenAI
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
|
@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
|
|||
streaming=True,
|
||||
)
|
||||
|
||||
def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
|
||||
return self.__dict__.get('_last_generation_info')
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
try:
|
||||
return super().get_num_tokens_from_messages(messages)
|
||||
return self.get_last_generation_info().get('input_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
try:
|
||||
return super().get_num_tokens(text)
|
||||
return self.get_last_generation_info().get('output_tokens', 0)
|
||||
except Exception as e:
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return len(tokenizer.encode(text))
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
input: LanguageModelInput,
|
||||
config: Optional[RunnableConfig] = None,
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseMessage:
|
||||
message = super().invoke(input, config, stop=stop, **kwargs)
|
||||
if isinstance(message.content, str):
|
||||
return message
|
||||
elif isinstance(message.content, list):
|
||||
# 构造新的响应消息返回
|
||||
content = message.content
|
||||
normalized_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get('type') == 'text':
|
||||
normalized_parts.append(item.get('text', ''))
|
||||
message.content = ''.join(normalized_parts)
|
||||
self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata)
|
||||
return message
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from requests.exceptions import ConnectTimeout, ReadTimeout
|
||||
from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
|
||||
|
|
@ -15,7 +16,7 @@ from langchain_openai import ChatOpenAI
|
|||
from langchain_openai.chat_models.base import _create_usage_metadata
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
def custom_get_token_ids(text: str):
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
|
|
@ -102,13 +103,13 @@ class BaseChatOpenAI(ChatOpenAI):
|
|||
future = executor.submit(super().get_num_tokens_from_messages, messages, tools)
|
||||
try:
|
||||
response = future.result()
|
||||
print("请求成功(未超时)")
|
||||
maxkb_logger.info("请求成功(未超时)")
|
||||
return response
|
||||
except Exception as e:
|
||||
if isinstance(e, ReadTimeout):
|
||||
raise # 继续抛出
|
||||
else:
|
||||
print("except:", e)
|
||||
maxkb_logger.error("except:", e)
|
||||
tokenizer = TokenizerManage.get_tokenizer()
|
||||
return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
|
||||
|
||||
|
|
@ -211,3 +212,20 @@ class BaseChatOpenAI(ChatOpenAI):
|
|||
self.usage_metadata = chat_result.response_metadata[
|
||||
'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
|
||||
return chat_result
|
||||
|
||||
|
||||
def upload_file_and_get_url(self, file_stream, file_name):
|
||||
"""上传文件并获取文件URL"""
|
||||
base64_video = base64.b64encode(file_stream).decode("utf-8")
|
||||
video_format = get_video_format(file_name)
|
||||
return f'data:{video_format};base64,{base64_video}'
|
||||
|
||||
def get_video_format(file_name):
|
||||
extension = file_name.split('.')[-1].lower()
|
||||
format_map = {
|
||||
'mp4': 'video/mp4',
|
||||
'avi': 'video/avi',
|
||||
'mov': 'video/mov',
|
||||
'wmv': 'video/x-ms-wmv'
|
||||
}
|
||||
return format_map.get(extension, 'video/mp4')
|
||||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 17:51
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class DeepSeekLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -56,7 +55,7 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -15,7 +14,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
|
|
@ -35,7 +34,7 @@ class GeminiEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -53,7 +52,7 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential):
|
|||
for chunk in res:
|
||||
maxkb_logger.info(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 17:57
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class GeminiLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -56,7 +55,7 @@ class GeminiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
|
||||
api_key = forms.PasswordInputField('API Key', required=True)
|
||||
|
|
@ -30,7 +29,7 @@ class GeminiSTTModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -23,6 +23,5 @@ class GeminiImage(MaxKBBaseModel, ChatGoogleGenerativeAI):
|
|||
return GeminiImage(
|
||||
model=model_name,
|
||||
google_api_key=model_credential.get('api_key'),
|
||||
streaming=True,
|
||||
**optional_params,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:06
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class KimiLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -56,7 +55,7 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/11/7 14:02
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: embedding.py
|
||||
@date:2024/7/11 11:06
|
||||
@Author:虎虎
|
||||
@file: model.py.py
|
||||
@date:2025/11/7 14:02
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -35,7 +34,7 @@ class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(gettext('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/7 14:03
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common import forms
|
||||
from common.forms import BaseForm
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import BaseModelCredential
|
||||
|
||||
|
||||
class LocalEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
prefix = CONFIG.get_admin_path()
|
||||
res = requests.post(
|
||||
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||
result = res.json()
|
||||
if result.get('code', 500) == 200:
|
||||
return result.get('data')
|
||||
raise Exception(result.get('message'))
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return model
|
||||
|
||||
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: __init__.py.py
|
||||
@date:2025/11/7 14:22
|
||||
@desc:
|
||||
"""
|
||||
import os
|
||||
|
||||
if os.environ.get('SERVER_NAME', 'web') == 'local_model':
|
||||
from .model import *
|
||||
else:
|
||||
from .web import *
|
||||
|
|
@ -1,12 +1,11 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: reranker.py
|
||||
@date:2024/9/3 14:33
|
||||
@Author:虎虎
|
||||
@file: model.py
|
||||
@date:2025/11/7 14:23
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
|
@ -17,7 +16,7 @@ from common.forms import BaseForm
|
|||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.local_model_provider.model.reranker import LocalReranker
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -36,7 +35,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
|||
model: LocalReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
@ -0,0 +1,37 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎虎
|
||||
@file: web.py
|
||||
@date:2025/11/7 14:23
|
||||
@desc:
|
||||
"""
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from common import forms
|
||||
from common.forms import BaseForm
|
||||
from maxkb.const import CONFIG
|
||||
from models_provider.base_model_provider import BaseModelCredential
|
||||
|
||||
|
||||
class LocalRerankerCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
|
||||
prefix = CONFIG.get_admin_path()
|
||||
res = requests.post(
|
||||
f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate',
|
||||
json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential})
|
||||
result = res.json()
|
||||
if result.get('code', 500) == 200:
|
||||
return result.get('data')
|
||||
raise Exception(result.get('message'))
|
||||
|
||||
def encryption_dict(self, model: Dict[str, object]):
|
||||
return model
|
||||
|
||||
cache_folder = forms.TextInputField(_('Model catalog'), required=True)
|
||||
|
|
@ -32,7 +32,6 @@ class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
print('ssss', kwargs.get('model_id', None))
|
||||
self.model_id = kwargs.get('model_id', None)
|
||||
|
||||
def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -15,6 +14,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class OpenAIEmbeddingModelParams(BaseForm):
|
||||
dimensions = forms.SingleSelect(
|
||||
|
|
@ -53,7 +53,7 @@ class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -56,7 +55,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential):
|
|||
for chunk in res:
|
||||
maxkb_logger.info(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -17,7 +16,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class OpenAILLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -58,7 +57,7 @@ class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException) or isinstance(e, BadRequestError):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class OpenAISTTModelParams(BaseForm):
|
||||
language = forms.TextInputField(
|
||||
|
|
@ -38,7 +37,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class OpenAITTIModelParams(BaseForm):
|
||||
size = forms.SingleSelect(
|
||||
|
|
@ -70,7 +69,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class OpenAITTSModelGeneralParams(BaseForm):
|
||||
# alloy, echo, fable, onyx, nova, shimmer
|
||||
|
|
@ -49,7 +48,7 @@ class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -15,7 +14,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
|
|
@ -35,7 +34,7 @@ class RegoloEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -12,7 +11,7 @@ from common.forms import BaseForm, TooltipLabel
|
|||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class RegoloImageModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -53,10 +52,8 @@ class RegoloImageModelCredential(BaseForm, BaseModelCredential):
|
|||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])])
|
||||
for chunk in res:
|
||||
print(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class RegoloLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -57,7 +56,7 @@ class RegoloLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class RegoloTTIModelParams(BaseForm):
|
||||
size = forms.SingleSelect(
|
||||
|
|
@ -68,9 +67,8 @@ class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||
try:
|
||||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
print(res)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/12 16:45
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -15,7 +14,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
|
|
@ -35,7 +34,7 @@ class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential)
|
||||
model.embed_query(_('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# coding=utf-8
|
||||
import base64
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -56,7 +55,7 @@ class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential):
|
|||
for chunk in res:
|
||||
maxkb_logger.info(chunk)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -57,7 +56,7 @@ class SiliconCloudLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/9/9 17:51
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException
|
|||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
|
|
@ -36,7 +35,7 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential):
|
|||
model: SiliconCloudReranker = provider.get_model(model_type, model_name, model_credential)
|
||||
model.compress_documents([Document(page_content=_('Hello'))], _('Hello'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext as _
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
|
||||
api_base = forms.TextInputField('API URL', required=True)
|
||||
|
|
@ -31,7 +30,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential,**model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,7 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudTTIModelParams(BaseForm):
|
||||
size = forms.SingleSelect(
|
||||
|
|
@ -70,7 +69,7 @@ class SiliconCloudTextToImageModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
res = model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# coding=utf-8
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -8,6 +7,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class SiliconCloudTTSModelGeneralParams(BaseForm):
|
||||
# alloy, echo, fable, onyx, nova, shimmer
|
||||
|
|
@ -50,7 +50,7 @@ class SiliconCloudTTSModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.check_auth()
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@date:2024/7/11 18:32
|
||||
@desc:
|
||||
"""
|
||||
import traceback
|
||||
from typing import Dict
|
||||
|
||||
from django.utils.translation import gettext_lazy as _, gettext
|
||||
|
|
@ -16,7 +15,7 @@ from common import forms
|
|||
from common.exception.app_exception import AppApiException
|
||||
from common.forms import BaseForm, TooltipLabel
|
||||
from models_provider.base_model_provider import BaseModelCredential, ValidCode
|
||||
|
||||
from common.utils.logger import maxkb_logger
|
||||
|
||||
class TencentCloudLLMModelParams(BaseForm):
|
||||
temperature = forms.SliderField(TooltipLabel(_('Temperature'),
|
||||
|
|
@ -57,7 +56,7 @@ class TencentCloudLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content=gettext('Hello'))])
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
maxkb_logger.error(f'Exception: {e}', exc_info=True)
|
||||
if isinstance(e, AppApiException):
|
||||
raise e
|
||||
if raise_exception:
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@
|
|||
@date:2024/4/18 15:28
|
||||
@desc:
|
||||
"""
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing import Dict
|
||||
|
||||
from common.config.tokenizer_manage_config import TokenizerManage
|
||||
from models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue