diff --git a/.github/workflows/build-and-push-vector-model.yml b/.github/workflows/build-and-push-vector-model.yml
index 556a398b8..0c51e86f6 100644
--- a/.github/workflows/build-and-push-vector-model.yml
+++ b/.github/workflows/build-and-push-vector-model.yml
@@ -5,7 +5,7 @@ on:
     inputs:
       dockerImageTag:
         description: 'Docker Image Tag'
-        default: 'v2.0.2'
+        default: 'v2.0.3'
         required: true
       architecture:
         description: 'Architecture'
diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml
index ec2bcf70e..647439edf 100644
--- a/.github/workflows/build-and-push.yml
+++ b/.github/workflows/build-and-push.yml
@@ -7,7 +7,7 @@ on:
     inputs:
       dockerImageTag:
         description: 'Image Tag'
-        default: 'v2.3.0-dev'
+        default: 'v2.4.0-dev'
        required: true
      dockerImageTagWithLatest:
        description: '是否发布latest tag(正式发版时选择,测试版本切勿选择)'
diff --git a/.gitignore b/.gitignore
index cc289d086..17f102d33 100644
--- a/.gitignore
+++ b/.gitignore
@@ -188,4 +188,5 @@ apps/models_provider/impl/*/icon/
 apps/models_provider/impl/tencent_model_provider/credential/stt.py
 apps/models_provider/impl/tencent_model_provider/model/stt.py
 tmp/
-config.yml
\ No newline at end of file
+config.yml
+.SANDBOX_BANNED_HOSTS
\ No newline at end of file
diff --git a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
index cd17d4a10..ed953e13c 100644
--- a/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
+++ b/apps/application/chat_pipeline/step/chat_step/impl/base_chat_step.py
@@ -93,6 +93,7 @@ def event_content(response,
                 reasoning_content_chunk = chunk.additional_kwargs.get('reasoning_content', '')
             else:
                 reasoning_content_chunk = reasoning_chunk.get('reasoning_content')
+            content_chunk = reasoning._normalize_content(content_chunk)
             all_text += content_chunk
             if reasoning_content_chunk is None:
                 reasoning_content_chunk = ''
@@ -191,23 +192,17 @@ class BaseChatStep(IChatStep):
                                        manage, padding_problem_text, chat_user_id, chat_user_type,
                                        no_references_setting, model_setting,
-                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                       mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                       mcp_output_enable)
         else:
             return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, chat_user_id, chat_user_type,
                                       no_references_setting, model_setting,
-                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                      mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                      mcp_output_enable)
 
     def get_details(self, manage, **kwargs):
-        # 删除临时生成的MCP代码文件
-        if self.context.get('execute_ids'):
-            executor = ToolExecutor(CONFIG.get('SANDBOX'))
-            # 清理工具代码文件,延时删除,避免文件被占用
-            for tool_id in self.context.get('execute_ids'):
-                code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
-                if os.path.exists(code_path):
-                    os.remove(code_path)
         return {
             'step_type': 'chat_step',
             'run_time': self.context['run_time'],
@@ -254,7 +249,6 @@ class BaseChatStep(IChatStep):
         if tool_enable:
             if tool_ids and len(tool_ids) > 0:  # 如果有工具ID,则将其转换为MCP
                 self.context['tool_ids'] = tool_ids
-                self.context['execute_ids'] = []
                 for tool_id in tool_ids:
                     tool = QuerySet(Tool).filter(id=tool_id).first()
                     if tool is None or tool.is_active is False:
@@ -264,9 +258,8 @@ class BaseChatStep(IChatStep):
                         params = json.loads(rsa_long_decrypt(tool.init_params))
                     else:
                         params = {}
-                    _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
+                    tool_config = executor.get_tool_mcp_config(tool.code, params)
 
-                    self.context['execute_ids'].append(_id)
                     mcp_servers_config[str(tool.id)] = tool_config
 
             if len(mcp_servers_config) > 0:
@@ -274,7 +267,6 @@ class BaseChatStep(IChatStep):
 
         return None
 
-
     def get_stream_result(self, message_list: List[BaseMessage],
                           chat_model: BaseChatModel = None,
                           paragraph_list=None,
@@ -304,7 +296,8 @@ class BaseChatStep(IChatStep):
         else:
             # 处理 MCP 请求
             mcp_result = self._handle_mcp_request(
-                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model, message_list,
+                mcp_enable, tool_enable, mcp_source, mcp_servers, mcp_tool_ids, tool_ids, mcp_output_enable, chat_model,
+                message_list,
             )
             if mcp_result:
                 return mcp_result, True
@@ -329,7 +322,8 @@ class BaseChatStep(IChatStep):
                        tool_ids=None,
                        mcp_output_enable=True):
         chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
-                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids,
+                                                         no_references_setting, problem_text, mcp_enable, mcp_tool_ids,
+                                                         mcp_servers, mcp_source, tool_enable, tool_ids,
                                                          mcp_output_enable)
         chat_record_id = uuid.uuid7()
         r = StreamingHttpResponse(
@@ -404,7 +398,9 @@ class BaseChatStep(IChatStep):
         # 调用模型
         try:
             chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
-                                                            no_references_setting, problem_text, mcp_enable, mcp_tool_ids, mcp_servers, mcp_source, tool_enable, tool_ids, mcp_output_enable)
+                                                            no_references_setting, problem_text, mcp_enable,
+                                                            mcp_tool_ids, mcp_servers, mcp_source, tool_enable,
+                                                            tool_ids, mcp_output_enable)
             if is_ai_chat:
                 request_token = chat_model.get_num_tokens_from_messages(message_list)
                 response_token = chat_model.get_num_tokens(chat_result.content)
@@ -416,10 +412,10 @@ class BaseChatStep(IChatStep):
             reasoning_result_end = reasoning.get_end_reasoning_content()
             content = reasoning_result.get('content') + reasoning_result_end.get('content')
             if 'reasoning_content' in chat_result.response_metadata:
-                reasoning_content = chat_result.response_metadata.get('reasoning_content', '')
+                reasoning_content = (chat_result.response_metadata.get('reasoning_content', '') or '')
             else:
-                reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get(
-                    'reasoning_content')
+                reasoning_content = (reasoning_result.get('reasoning_content') or "") + (reasoning_result_end.get(
+                    'reasoning_content') or "")
             post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                           content, manage, self, padding_problem_text,
                                           reasoning_content=reasoning_content)
diff --git a/apps/application/flow/i_step_node.py b/apps/application/flow/i_step_node.py
index 1eee330db..c30491140 100644
--- a/apps/application/flow/i_step_node.py
+++ b/apps/application/flow/i_step_node.py
@@ -95,6 +95,7 @@ class WorkFlowPostHandler:
             application_public_access_client.access_num = application_public_access_client.access_num + 1
             application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1
             application_public_access_client.save()
+        self.chat_info = None
 
 
 class KnowledgeWorkflowPostHandler(WorkFlowPostHandler):
diff --git a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
index 20dc22b25..630b2e9b6 100644
--- a/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
+++ b/apps/application/flow/step_node/ai_chat_step_node/impl/base_chat_node.py
@@ -108,9 +108,9 @@ def write_context(node_variable: Dict, workflow_variable: Dict, node: INode, wor
     content = reasoning_result.get('content') + reasoning_result_end.get('content')
     meta = {**response.response_metadata, **response.additional_kwargs}
     if 'reasoning_content' in meta:
-        reasoning_content = meta.get('reasoning_content', '')
+        reasoning_content = (meta.get('reasoning_content', '') or '')
     else:
-        reasoning_content = reasoning_result.get('reasoning_content') + reasoning_result_end.get('reasoning_content')
+        reasoning_content = (reasoning_result.get('reasoning_content') or '') + (reasoning_result_end.get('reasoning_content') or '')
     _write_context(node_variable, workflow_variable, node, workflow, content, reasoning_content)
 
 
@@ -233,7 +233,6 @@ class BaseChatNode(IChatNode):
         if tool_enable:
             if tool_ids and len(tool_ids) > 0:  # 如果有工具ID,则将其转换为MCP
                 self.context['tool_ids'] = tool_ids
-                self.context['execute_ids'] = []
                 for tool_id in tool_ids:
                     tool = QuerySet(Tool).filter(id=tool_id).first()
                     if not tool.is_active:
@@ -243,9 +242,8 @@ class BaseChatNode(IChatNode):
                         params = json.loads(rsa_long_decrypt(tool.init_params))
                     else:
                         params = {}
-                    _id, tool_config = executor.get_tool_mcp_config(tool.code, params)
+                    tool_config = executor.get_tool_mcp_config(tool.code, params)
 
-                    self.context['execute_ids'].append(_id)
                     mcp_servers_config[str(tool.id)] = tool_config
 
             if len(mcp_servers_config) > 0:
@@ -307,14 +305,6 @@ class BaseChatNode(IChatNode):
         return result
 
     def get_details(self, index: int, **kwargs):
-        # 删除临时生成的MCP代码文件
-        if self.context.get('execute_ids'):
-            executor = ToolExecutor(CONFIG.get('SANDBOX'))
-            # 清理工具代码文件,延时删除,避免文件被占用
-            for tool_id in self.context.get('execute_ids'):
-                code_path = f'{executor.sandbox_path}/execute/{tool_id}.py'
-                if os.path.exists(code_path):
-                    os.remove(code_path)
         return {
             'name': self.node.properties.get('stepName'),
             "index": index,
diff --git a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py
index ceba18eee..aa146cea2 100644
--- a/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py
+++ b/apps/application/flow/step_node/image_to_video_step_node/impl/base_image_to_video_node.py
@@ -74,7 +74,7 @@ class BaseImageToVideoNode(IImageToVideoNode):
     def get_file_base64(self, image_url):
         try:
             if isinstance(image_url, list):
-                image_url = image_url[0].get('file_id')
+                image_url = image_url[0].get('file_id') if 'file_id' in image_url[0] else image_url[0].get('url')
             if isinstance(image_url, str) and not image_url.startswith('http'):
                 file = QuerySet(File).filter(id=image_url).first()
                 file_bytes = file.get_bytes()
diff --git a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
index 7f57b4340..f93e670c1 100644
--- a/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
+++ b/apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
@@ -131,11 +131,18 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                 image_list = data['image_list']
                 if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                     return HumanMessage(content=chat_record.problem_text)
-                file_id_list = [image.get('file_id') for image in image_list]
+
+                file_id_list = []
+                url_list = []
+                for image in image_list:
+                    if 'file_id' in image:
+                        file_id_list.append(image.get('file_id'))
+                    elif 'url' in image:
+                        url_list.append(image.get('url'))
                 return HumanMessage(content=[
                     {'type': 'text', 'text': data['question']},
-                    *[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
-
+                    *[{'type': 'image_url', 'image_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
+                    *[{'type': 'image_url', 'image_url': {'url': url}} for url in url_list]
                 ])
             return HumanMessage(content=chat_record.problem_text)
 
@@ -155,13 +162,22 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                 image_list = data['image_list']
                 if len(image_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                     return HumanMessage(content=chat_record.problem_text)
-                image_base64_list = [file_id_to_base64(image.get('file_id')) for image in image_list]
+                file_id_list = []
+                url_list = []
+                for image in image_list:
+                    if 'file_id' in image:
+                        file_id_list.append(image.get('file_id'))
+                    elif 'url' in image:
+                        url_list.append(image.get('url'))
+                image_base64_list = [file_id_to_base64(file_id) for file_id in file_id_list]
+
                 return HumanMessage(
                     content=[
                         {'type': 'text', 'text': data['question']},
                         *[{'type': 'image_url', 'image_url': {'url': f'data:image/{base64_image[1]};base64,{base64_image[0]}'}} for
-                          base64_image in image_base64_list]
+                          base64_image in image_base64_list],
+                        *[{'type': 'image_url', 'image_url': url} for url in url_list]
                     ])
             return HumanMessage(content=chat_record.problem_text)
 
@@ -177,13 +193,17 @@ class BaseImageUnderstandNode(IImageUnderstandNode):
                 images.append({'type': 'image_url', 'image_url': {'url': image}})
             elif image is not None and len(image) > 0:
                 for img in image:
-                    file_id = img['file_id']
-                    file = QuerySet(File).filter(id=file_id).first()
-                    image_bytes = file.get_bytes()
-                    base64_image = base64.b64encode(image_bytes).decode("utf-8")
-                    image_format = what(None, image_bytes)
-                    images.append(
-                        {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
+                    if 'file_id' in img:
+                        file_id = img['file_id']
+                        file = QuerySet(File).filter(id=file_id).first()
+                        image_bytes = file.get_bytes()
+                        base64_image = base64.b64encode(image_bytes).decode("utf-8")
+                        image_format = what(None, image_bytes)
+                        images.append(
+                            {'type': 'image_url', 'image_url': {'url': f'data:image/{image_format};base64,{base64_image}'}})
+                    elif 'url' in img and img['url'].startswith('http'):
+                        images.append(
+                            {'type': 'image_url', 'image_url': {'url': img["url"]}})
         return images
 
     def generate_message_list(self, image_model, system: str, prompt: str, history_message, image):
diff --git a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py
index 8144512f0..92083a6da 100644
--- a/apps/application/flow/step_node/loop_node/impl/base_loop_node.py
+++ b/apps/application/flow/step_node/loop_node/impl/base_loop_node.py
@@ -226,11 +226,14 @@ class BaseLoopNode(ILoopNode):
     def save_context(self, details, workflow_manage):
         self.context['loop_context_data'] = details.get('loop_context_data')
         self.context['loop_answer_data'] = details.get('loop_answer_data')
+        self.context['loop_node_data'] = details.get('loop_node_data')
         self.context['result'] = details.get('result')
         self.context['params'] = details.get('params')
         self.context['run_time'] = details.get('run_time')
         self.context['index'] = details.get('current_index')
         self.context['item'] = details.get('current_item')
+        for key, value in (details.get('loop_context_data') or {}).items():
+            self.context[key] = value
         self.answer_text = ""
 
     def get_answer_list(self) -> List[Answer] | None:
diff --git a/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py b/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py
index 2d8e3fc4c..9d1dba37d 100644
--- a/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py
+++ b/apps/application/flow/step_node/text_to_video_step_node/impl/base_text_to_video_node.py
@@ -37,7 +37,6 @@ class BaseTextToVideoNode(ITextToVideoNode):
         self.context['dialogue_type'] = dialogue_type
         self.context['negative_prompt'] = self.generate_prompt_question(negative_prompt)
         video_urls = ttv_model.generate_video(question, negative_prompt)
-        print('video_urls', video_urls)
         # 保存图片
         if video_urls is None:
             return NodeResult({'answer': gettext('Failed to generate video')}, {})
diff --git a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
index 2118e6a9f..9a478e6c9 100644
--- a/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
+++ b/apps/application/flow/step_node/video_understand_step_node/impl/base_video_understand_node.py
@@ -1,9 +1,7 @@
 # coding=utf-8
-import base64
-import mimetypes
+
 import time
 from functools import reduce
-from imghdr import what
 from typing import List, Dict
 
 from django.db.models import QuerySet
@@ -12,7 +10,6 @@ from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage, AI
 from application.flow.i_step_node import NodeResult, INode
 from application.flow.step_node.video_understand_step_node.i_video_understand_node import IVideoUnderstandNode
 from knowledge.models import File
-from models_provider.impl.volcanic_engine_model_provider.model.image import get_video_format
 from models_provider.tools import get_model_instance_by_model_workspace_id
 
 
@@ -134,11 +131,17 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
             # 增加对 None 和空列表的检查
             if not video_list or len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                 return HumanMessage(content=chat_record.problem_text)
-            file_id_list = [video.get('file_id') for video in video_list]
+            file_id_list = []
+            url_list = []
+            for image in video_list:
+                if 'file_id' in image:
+                    file_id_list.append(image.get('file_id'))
+                elif 'url' in image:
+                    url_list.append(image.get('url'))
             return HumanMessage(content=[
                 {'type': 'text', 'text': data['question']},
-                *[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list]
-
+                *[{'type': 'video_url', 'video_url': {'url': f'./oss/file/{file_id}'}} for file_id in file_id_list],
+                *[{'type': 'video_url', 'video_url': {'url': url}} for url in url_list],
             ])
         return HumanMessage(content=chat_record.problem_text)
 
@@ -158,6 +161,13 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
                 video_list = data['video_list']
                 if len(video_list) == 0 or data['dialogue_type'] == 'WORKFLOW':
                     return HumanMessage(content=chat_record.problem_text)
+                file_id_list = []
+                url_list = []
+                for image in video_list:
+                    if 'file_id' in image:
+                        file_id_list.append(image.get('file_id'))
+                    elif 'url' in image:
+                        url_list.append(image.get('url'))
                 video_base64_list = [file_id_to_base64(video.get('file_id'), video_model) for video in video_list]
 
                 return HumanMessage(
                     content=[
@@ -177,11 +187,15 @@ class BaseVideoUnderstandNode(IVideoUnderstandNode):
                 videos.append({'type': 'video_url', 'video_url': {'url': image}})
             elif image is not None and len(image) > 0:
                 for img in image:
-                    file_id = img['file_id']
-                    file = QuerySet(File).filter(id=file_id).first()
-                    url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
-                    videos.append(
-                        {'type': 'video_url', 'video_url': {'url': url}})
+                    if 'file_id' in img:
+                        file_id = img['file_id']
+                        file = QuerySet(File).filter(id=file_id).first()
+                        url = video_model.upload_file_and_get_url(file.get_bytes(), file.file_name)
+                        videos.append(
+                            {'type': 'video_url', 'video_url': {'url': url}})
+                    elif 'url' in img and img['url'].startswith('http'):
+                        videos.append(
+                            {'type': 'video_url', 'video_url': {'url': img['url']}})
         return videos
 
     def generate_message_list(self, video_model, system: str, prompt: str, history_message, video):
diff --git a/apps/application/flow/tools.py b/apps/application/flow/tools.py
index 527b11ce8..149eafaa0 100644
--- a/apps/application/flow/tools.py
+++ b/apps/application/flow/tools.py
@@ -8,7 +8,8 @@
 """
 import asyncio
 import json
-import traceback
+import queue
+import threading
 from typing import Iterator
 
 from django.http import StreamingHttpResponse
@@ -47,6 +48,21 @@ class Reasoning:
                 return r
         return {'content': '', 'reasoning_content': ''}
 
+    def _normalize_content(self, content):
+        """将不同类型的内容统一转换为字符串"""
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            # 处理包含多种内容类型的列表
+            normalized_parts = []
+            for item in content:
+                if isinstance(item, dict):
+                    if item.get('type') == 'text':
+                        normalized_parts.append(item.get('text', ''))
+            return ''.join(normalized_parts)
+        else:
+            return str(content)
+
     def get_reasoning_content(self, chunk):
         # 如果没有开始思考过程标签那么就全是结果
         if self.reasoning_content_start_tag is None or len(self.reasoning_content_start_tag) == 0:
@@ -55,6 +71,7 @@ class Reasoning:
         # 如果没有结束思考过程标签那么就全部是思考过程
         if self.reasoning_content_end_tag is None or len(self.reasoning_content_end_tag) == 0:
             return {'content': '', 'reasoning_content': chunk.content}
+        chunk.content = self._normalize_content(chunk.content)
         self.all_content += chunk.content
         if not self.reasoning_content_is_start and len(self.all_content) >= self.reasoning_content_start_tag_len:
             if self.all_content.startswith(self.reasoning_content_start_tag):
@@ -201,17 +218,6 @@ def to_stream_response_simple(stream_event):
     r['Cache-Control'] = 'no-cache'
     return r
 
-tool_message_template = """
-<details>
-        <summary>
-                Called MCP Tool: %s
-        </summary>
-
-%s
-
-</details>
-
-"""
 
 
 tool_message_json_template = """
@@ -219,42 +225,142 @@ tool_message_json_template = """
 ```json
 %s
 ```
 """
 
+tool_message_complete_template = """
+<details>
+        <summary>
+                Called MCP Tool: %s
+        </summary>
-def generate_tool_message_template(name, context):
-    if '```' in context:
-        return tool_message_template % (name, context)
+
+**Input:**
+%s
+
+**Output:**
+%s
+
+</details>
+
+"""
+
+
+def generate_tool_message_complete(name, input_content, output_content):
+    """生成包含输入和输出的工具消息模版"""
+    # 格式化输入
+    if '```' not in input_content:
+        input_formatted = tool_message_json_template % input_content
     else:
-        return tool_message_template % (name, tool_message_json_template % (context))
+        input_formatted = input_content
+
+    # 格式化输出
+    if '```' not in output_content:
+        output_formatted = tool_message_json_template % output_content
+    else:
+        output_formatted = output_content
+
+    return tool_message_complete_template % (name, input_formatted, output_formatted)
+
+
+# 全局单例事件循环
+_global_loop = None
+_loop_thread = None
+_loop_lock = threading.Lock()
+
+
+def get_global_loop():
+    """获取全局共享的事件循环"""
+    global _global_loop, _loop_thread
+
+    with _loop_lock:
+        if _global_loop is None:
+            _global_loop = asyncio.new_event_loop()
+
+            def run_forever():
+                asyncio.set_event_loop(_global_loop)
+                _global_loop.run_forever()
+
+            _loop_thread = threading.Thread(target=run_forever, daemon=True, name="GlobalAsyncLoop")
+            _loop_thread.start()
+
+    return _global_loop
 
 
 async def _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable=True):
-    client = MultiServerMCPClient(json.loads(mcp_servers))
-    tools = await client.get_tools()
-    agent = create_react_agent(chat_model, tools)
-    response = agent.astream({"messages": message_list}, stream_mode='messages')
-    async for chunk in response:
-        if mcp_output_enable and isinstance(chunk[0], ToolMessage):
-            content = generate_tool_message_template(chunk[0].name, chunk[0].content)
-            chunk[0].content = content
-            yield chunk[0]
-        if isinstance(chunk[0], AIMessageChunk):
-            yield chunk[0]
+    try:
+        client = MultiServerMCPClient(json.loads(mcp_servers))
+        tools = await client.get_tools()
+        agent = create_react_agent(chat_model, tools)
+        response = agent.astream({"messages": message_list}, stream_mode='messages')
+
+        # 用于存储工具调用信息
+        tool_calls_info = {}
+
+        async for chunk in response:
+            if isinstance(chunk[0], AIMessageChunk):
+                tool_calls = chunk[0].additional_kwargs.get('tool_calls', [])
+                for tool_call in tool_calls:
+                    tool_id = tool_call.get('id', '')
+                    if tool_id:
+                        # 保存工具调用的输入
+                        tool_calls_info[tool_id] = {
+                            'name': tool_call.get('function', {}).get('name', ''),
+                            'input': tool_call.get('function', {}).get('arguments', '')
+                        }
+                yield chunk[0]
+
+            if mcp_output_enable and isinstance(chunk[0], ToolMessage):
+                tool_id = chunk[0].tool_call_id
+                if tool_id in tool_calls_info:
+                    # 合并输入和输出
+                    tool_info = tool_calls_info[tool_id]
+                    content = generate_tool_message_complete(
+                        tool_info['name'],
+                        tool_info['input'],
+                        chunk[0].content
+                    )
+                    chunk[0].content = content
+                yield chunk[0]
+
+    except ExceptionGroup as eg:
+
+        def get_real_error(exc):
+            if isinstance(exc, ExceptionGroup):
+                return get_real_error(exc.exceptions[0])
+            return exc
+
+        real_error = get_real_error(eg)
+        error_msg = f"{type(real_error).__name__}: {str(real_error)}"
+        raise RuntimeError(error_msg) from None
+
+    except Exception as e:
+        error_msg = f"{type(e).__name__}: {str(e)}"
+        raise RuntimeError(error_msg) from None
 
 
 def mcp_response_generator(chat_model, message_list, mcp_servers, mcp_output_enable=True):
-    loop = asyncio.new_event_loop()
-    try:
-        async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable)
-        while True:
-            try:
-                chunk = loop.run_until_complete(anext_async(async_gen))
-                yield chunk
-            except StopAsyncIteration:
-                break
-            except Exception as e:
-                maxkb_logger.error(f'Exception: {e}', exc_info=True)
-    finally:
-        loop.close()
"""使用全局事件循环,不创建新实例""" + result_queue = queue.Queue() + loop = get_global_loop() # 使用共享循环 + + async def _run(): + try: + async_gen = _yield_mcp_response(chat_model, message_list, mcp_servers, mcp_output_enable) + async for chunk in async_gen: + result_queue.put(('data', chunk)) + except Exception as e: + maxkb_logger.error(f'Exception: {e}', exc_info=True) + result_queue.put(('error', e)) + finally: + result_queue.put(('done', None)) + + # 在全局循环中调度任务 + asyncio.run_coroutine_threadsafe(_run(), loop) + + while True: + msg_type, data = result_queue.get() + if msg_type == 'done': + break + if msg_type == 'error': + raise data + yield data async def anext_async(agen): diff --git a/apps/application/flow/workflow_manage.py b/apps/application/flow/workflow_manage.py index 78a458b64..2bec10326 100644 --- a/apps/application/flow/workflow_manage.py +++ b/apps/application/flow/workflow_manage.py @@ -9,7 +9,6 @@ import concurrent import json import threading -import traceback from concurrent.futures import ThreadPoolExecutor from functools import reduce from typing import List, Dict @@ -26,6 +25,7 @@ from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult, from application.flow.step_node import get_node from common.handle.base_to_response import BaseToResponse from common.handle.impl.response.system_to_response import SystemToResponse +from common.utils.logger import maxkb_logger executor = ThreadPoolExecutor(max_workers=200) @@ -234,25 +234,62 @@ class WorkflowManage: 非流式响应 @return: 结果 """ - self.run_chain_async(None, None, language) - while self.is_run(): - pass - details = self.get_runtime_details() - message_tokens = sum([row.get('message_tokens') for row in details.values() if - 'message_tokens' in row and row.get('message_tokens') is not None]) - answer_tokens = sum([row.get('answer_tokens') for row in details.values() if - 'answer_tokens' in row and row.get('answer_tokens') is not None]) - answer_text_list = self.get_answer_text_list() - answer_text = '\n\n'.join( - '\n\n'.join([a.get('content') for a in answer]) for answer in - answer_text_list) - answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) - self.work_flow_post_handler.handler(self) - return self.base_to_response.to_block_response(self.params['chat_id'], - self.params['chat_record_id'], answer_text, True - , message_tokens, answer_tokens, - _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, - other_params={'answer_list': answer_list}) + try: + self.run_chain_async(None, None, language) + while self.is_run(): + pass + details = self.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + answer_text_list = self.get_answer_text_list() + answer_text = '\n\n'.join( + '\n\n'.join([a.get('content') for a in answer]) for answer in + answer_text_list) + answer_list = reduce(lambda pre, _n: [*pre, *_n], answer_text_list, []) + self.work_flow_post_handler.handler(self) + + res = self.base_to_response.to_block_response(self.params['chat_id'], + self.params['chat_record_id'], answer_text, True + , message_tokens, answer_tokens, + _status=status.HTTP_200_OK if self.status == 200 else status.HTTP_500_INTERNAL_SERVER_ERROR, + other_params={'answer_list': answer_list}) + finally: + self._cleanup() + return 
res + + def _cleanup(self): + """清理所有对象引用""" + # 清理列表 + self.future_list.clear() + self.field_list.clear() + self.global_field_list.clear() + self.chat_field_list.clear() + self.image_list.clear() + self.video_list.clear() + self.document_list.clear() + self.audio_list.clear() + self.other_list.clear() + if hasattr(self, 'node_context'): + self.node_context.clear() + + # 清理字典 + self.context.clear() + self.chat_context.clear() + self.form_data.clear() + + # 清理对象引用 + self.node_chunk_manage = None + self.work_flow_post_handler = None + self.flow = None + self.start_node = None + self.current_node = None + self.current_result = None + self.chat_record = None + self.base_to_response = None + self.params = None + self.lock = None def run_stream(self, current_node, node_result_future, language='zh'): """ @@ -307,6 +344,7 @@ class WorkflowManage: '', [], '', True, message_tokens, answer_tokens, {}) + self._cleanup() def run_chain_async(self, current_node, node_result_future, language='zh'): future = executor.submit(self.run_chain_manage, current_node, node_result_future, language) @@ -345,7 +383,7 @@ class WorkflowManage: current_node, node_result_future) return result except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) return None def hand_node_result(self, current_node, node_result_future): @@ -357,7 +395,7 @@ class WorkflowManage: list(result) return current_result except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) self.status = 500 current_node.get_write_error_context(e) self.answer += str(e) @@ -435,9 +473,10 @@ class WorkflowManage: return current_result except Exception as e: # 添加节点 - traceback.print_exc() + + maxkb_logger.error(f'Exception: {e}', exc_info=True) chunk = self.base_to_response.to_stream_chunk_response(self.params.get('chat_id'), - self.params.get('chat_record_id'), + self.params.get('chat_id'), current_node.id, current_node.up_node_id_list, 'Exception:' + str(e), False, 0, 0, diff --git a/apps/application/serializers/application_stats.py b/apps/application/serializers/application_stats.py index 3be5f211f..8d45a5311 100644 --- a/apps/application/serializers/application_stats.py +++ b/apps/application/serializers/application_stats.py @@ -118,3 +118,37 @@ class ApplicationStatisticsSerializer(serializers.Serializer): days.append(current_date.strftime('%Y-%m-%d')) current_date += datetime.timedelta(days=1) return days + + def get_token_usage_statistics(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + start_time = self.get_start_time() + end_time = self.get_end_time() + get_token_usage = native_search( + {'default_sql': QuerySet(model=get_dynamics_model( + {'application_chat.application_id': models.UUIDField(), + 'application_chat_record.create_time': models.DateTimeField()})).filter( + **{'application_chat.application_id': self.data.get('application_id'), + 'application_chat_record.create_time__gte': start_time, + 'application_chat_record.create_time__lte': end_time} + )}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'get_token_usage.sql'))) + return get_token_usage + + def get_top_questions_statistics(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + start_time = self.get_start_time() + end_time = self.get_end_time() + get_top_questions = native_search( + {'default_sql': QuerySet(model=get_dynamics_model( + {'application_chat.application_id': models.UUIDField(), + 
'application_chat_record.create_time': models.DateTimeField()})).filter( + **{'application_chat.application_id': self.data.get('application_id'), + 'application_chat_record.create_time__gte': start_time, + 'application_chat_record.create_time__lte': end_time} + )}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'top_questions.sql'))) + return get_top_questions \ No newline at end of file diff --git a/apps/application/serializers/common.py b/apps/application/serializers/common.py index 10d7efb2e..5c60512ad 100644 --- a/apps/application/serializers/common.py +++ b/apps/application/serializers/common.py @@ -6,7 +6,7 @@ @date:2025/6/9 13:42 @desc: """ -from datetime import datetime +import json from typing import List from django.core.cache import cache @@ -20,6 +20,7 @@ from application.models import Application, ChatRecord, Chat, ApplicationVersion from application.serializers.application_chat import ChatCountSerializer from common.constants.cache_version import Cache_Version from common.database_model_manage.database_model_manage import DatabaseModelManage +from common.encoder.encoder import SystemEncoder from common.exception.app_exception import ChatException from knowledge.models import Document from models_provider.models import Model @@ -226,10 +227,69 @@ class ChatInfo: chat_record.save() ChatCountSerializer(data={'chat_id': self.chat_id}).update_chat() + def to_dict(self): + + return { + 'chat_id': self.chat_id, + 'chat_user_id': self.chat_user_id, + 'chat_user_type': self.chat_user_type, + 'knowledge_id_list': self.knowledge_id_list, + 'exclude_document_id_list': self.exclude_document_id_list, + 'application_id': self.application_id, + 'chat_record_list': [self.chat_record_to_map(c) for c in self.chat_record_list][-20:], + 'debug': self.debug + } + + def chat_record_to_map(self, chat_record): + return {'id': chat_record.id, + 'chat_id': chat_record.chat_id, + 'vote_status': chat_record.vote_status, + 'problem_text': chat_record.problem_text, + 'answer_text': chat_record.answer_text, + 'answer_text_list': chat_record.answer_text_list, + 'message_tokens': chat_record.message_tokens, + 'answer_tokens': chat_record.answer_tokens, + 'const': chat_record.const, + 'details': chat_record.details, + 'improve_paragraph_id_list': chat_record.improve_paragraph_id_list, + 'run_time': chat_record.run_time, + 'index': chat_record.index} + + @staticmethod + def map_to_chat_record(chat_record_dict): + return ChatRecord(id=chat_record_dict.get('id'), + chat_id=chat_record_dict.get('chat_id'), + vote_status=chat_record_dict.get('vote_status'), + problem_text=chat_record_dict.get('problem_text'), + answer_text=chat_record_dict.get('answer_text'), + answer_text_list=chat_record_dict.get('answer_text_list'), + message_tokens=chat_record_dict.get('message_tokens'), + answer_tokens=chat_record_dict.get('answer_tokens'), + const=chat_record_dict.get('const'), + details=chat_record_dict.get('details'), + improve_paragraph_id_list=chat_record_dict.get('improve_paragraph_id_list'), + run_time=chat_record_dict.get('run_time'), + index=chat_record_dict.get('index'), ) + def set_cache(self): - cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self, version=Cache_Version.CHAT.get_version(), + cache.set(Cache_Version.CHAT.get_key(key=self.chat_id), self.to_dict(), + version=Cache_Version.CHAT_INFO.get_version(), timeout=60 * 30) + @staticmethod + def map_to_chat_info(chat_info_dict): + c = ChatInfo(chat_info_dict.get('chat_id'), chat_info_dict.get('chat_user_id'), 
+ chat_info_dict.get('chat_user_type'), chat_info_dict.get('knowledge_id_list'), + chat_info_dict.get('exclude_document_id_list'), + chat_info_dict.get('application_id'), + debug=chat_info_dict.get('debug')) + c.chat_record_list = [ChatInfo.map_to_chat_record(c_r) for c_r in chat_info_dict.get('chat_record_list')] + return c + @staticmethod def get_cache(chat_id): - return cache.get(Cache_Version.CHAT.get_key(key=chat_id), version=Cache_Version.CHAT.get_version()) + chat_info_dict = cache.get(Cache_Version.CHAT.get_key(key=chat_id), + version=Cache_Version.CHAT_INFO.get_version()) + if chat_info_dict: + return ChatInfo.map_to_chat_info(chat_info_dict) + return None diff --git a/apps/application/sql/get_token_usage.sql b/apps/application/sql/get_token_usage.sql new file mode 100644 index 000000000..848a36430 --- /dev/null +++ b/apps/application/sql/get_token_usage.sql @@ -0,0 +1,12 @@ + +SELECT + SUM(application_chat_record.message_tokens + application_chat_record.answer_tokens) as "token_usage", + COALESCE(application_chat.asker->>'username', '游客') as "username" +FROM + application_chat_record application_chat_record + LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id +${default_sql} +GROUP BY + COALESCE(application_chat.asker->>'username', '游客') +ORDER BY + "token_usage" DESC diff --git a/apps/application/sql/top_questions.sql b/apps/application/sql/top_questions.sql new file mode 100644 index 000000000..5997934ec --- /dev/null +++ b/apps/application/sql/top_questions.sql @@ -0,0 +1,11 @@ +SELECT COUNT(application_chat_record."id") AS chat_record_count, + COALESCE(application_chat.asker ->>'username', '游客') AS username +FROM application_chat_record application_chat_record + LEFT JOIN application_chat application_chat ON application_chat."id" = application_chat_record.chat_id + ${default_sql} +GROUP BY + COALESCE (application_chat.asker->>'username', '游客') +ORDER BY + chat_record_count DESC, + username ASC + diff --git a/apps/application/urls.py b/apps/application/urls.py index fb6a8ac3d..34ded9fe0 100644 --- a/apps/application/urls.py +++ b/apps/application/urls.py @@ -13,6 +13,8 @@ urlpatterns = [ path('workspace//application//publish', views.ApplicationAPI.Publish.as_view()), path('workspace//application//application_key', views.ApplicationKey.as_view()), path('workspace//application//application_stats', views.ApplicationStats.as_view()), + path('workspace//application//application_token_usage', views.ApplicationStats.TokenUsageStatistics.as_view()), + path('workspace//application//top_questions', views.ApplicationStats.TopQuestionsStatistics.as_view()), path('workspace//application//application_key/', views.ApplicationKey.Operate.as_view()), path('workspace//application//export', views.ApplicationAPI.Export.as_view()), path('workspace//application//application_version', views.ApplicationVersionView.as_view()), diff --git a/apps/application/views/application_chat.py b/apps/application/views/application_chat.py index 381bde9dd..d60206ded 100644 --- a/apps/application/views/application_chat.py +++ b/apps/application/views/application_chat.py @@ -125,8 +125,8 @@ class OpenView(APIView): responses=None, tags=[_('Application')] # type: ignore ) - @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(), - PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(), + @has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(), + 
+                     PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
                      ViewPermission([RoleConstants.USER.get_workspace_role()],
                                     [PermissionConstants.APPLICATION.get_workspace_application_permission()],
                                     CompareConstants.AND),
@@ -167,8 +167,8 @@ class PromptGenerateView(APIView):
         responses=None,
         tags=[_('Application')]  # type: ignore
     )
-    @has_permissions(PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
-                     PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
+    @has_permissions(PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
+                     PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
                      ViewPermission([RoleConstants.USER.get_workspace_role()],
                                     [PermissionConstants.APPLICATION.get_workspace_application_permission()],
                                     CompareConstants.AND),
diff --git a/apps/application/views/application_chat_record.py b/apps/application/views/application_chat_record.py
index dbcb246e4..0d59146b2 100644
--- a/apps/application/views/application_chat_record.py
+++ b/apps/application/views/application_chat_record.py
@@ -93,8 +93,8 @@ class ApplicationChatRecordOperateAPI(APIView):
     )
     @has_permissions(PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_application_permission(),
                      PermissionConstants.APPLICATION_CHAT_LOG_READ.get_workspace_permission_workspace_manage_role(),
-                     PermissionConstants.APPLICATION_EDIT.get_workspace_application_permission(),
-                     PermissionConstants.APPLICATION_EDIT.get_workspace_permission_workspace_manage_role(),
+                     PermissionConstants.APPLICATION_READ.get_workspace_application_permission(),
+                     PermissionConstants.APPLICATION_READ.get_workspace_permission_workspace_manage_role(),
                      ViewPermission([RoleConstants.USER.get_workspace_role()],
                                     [PermissionConstants.APPLICATION.get_workspace_application_permission()],
                                     CompareConstants.AND),
diff --git a/apps/application/views/application_stats.py b/apps/application/views/application_stats.py
index 17b43fe37..456715617 100644
--- a/apps/application/views/application_stats.py
+++ b/apps/application/views/application_stats.py
@@ -46,3 +46,58 @@ class ApplicationStats(APIView):
                                                            'end_time': request.query_params.get(
                                                                'end_time')
                                                            }).get_chat_record_aggregate_trend())
+
+    class TokenUsageStatistics(APIView):
+        authentication_classes = [TokenAuth]
+
+        # 应用的token使用统计 根据人的使用数排序
+        @extend_schema(
+            methods=['GET'],
+            description=_('Application token usage statistics'),
+            summary=_('Application token usage statistics'),
+            operation_id=_('Application token usage statistics'),  # type: ignore
+            parameters=ApplicationStatsAPI.get_parameters(),
+            responses=ApplicationStatsAPI.get_response(),
+            tags=[_('Application')]  # type: ignore
+        )
+        @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
+                         PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
+                         ViewPermission([RoleConstants.USER.get_workspace_role()],
+                                        [PermissionConstants.APPLICATION.get_workspace_application_permission()],
+                                        CompareConstants.AND),
+                         RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
+        def get(self, request: Request, workspace_id: str, application_id: str):
+            return result.success(
+                ApplicationStatisticsSerializer(data={'application_id': application_id, 'workspace_id': workspace_id,
+                                                      'start_time': request.query_params.get(
+                                                          'start_time'),
+                                                      'end_time': request.query_params.get(
+                                                          'end_time')
+                                                      }).get_token_usage_statistics())
+
+    class TopQuestionsStatistics(APIView):
+        authentication_classes = [TokenAuth]
+        # 应用的top问题统计
+        @extend_schema(
+            methods=['GET'],
+            description=_('Application top question statistics'),
+            summary=_('Application top question statistics'),
+            operation_id=_('Application top question statistics'),  # type: ignore
+            parameters=ApplicationStatsAPI.get_parameters(),
+            responses=ApplicationStatsAPI.get_response(),
+            tags=[_('Application')]  # type: ignore
+        )
+        @has_permissions(PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_application_permission(),
+                         PermissionConstants.APPLICATION_OVERVIEW_READ.get_workspace_permission_workspace_manage_role(),
+                         ViewPermission([RoleConstants.USER.get_workspace_role()],
+                                        [PermissionConstants.APPLICATION.get_workspace_application_permission()],
+                                        CompareConstants.AND),
+                         RoleConstants.WORKSPACE_MANAGE.get_workspace_role())
+        def get(self, request: Request, workspace_id: str, application_id: str):
+            return result.success(
+                ApplicationStatisticsSerializer(data={'application_id': application_id, 'workspace_id': workspace_id,
+                                                      'start_time': request.query_params.get(
+                                                          'start_time'),
+                                                      'end_time': request.query_params.get(
+                                                          'end_time')
+                                                      }).get_top_questions_statistics())
diff --git a/apps/common/auth/authenticate.py b/apps/common/auth/authenticate.py
index a762fed11..30bccf590 100644
--- a/apps/common/auth/authenticate.py
+++ b/apps/common/auth/authenticate.py
@@ -6,7 +6,6 @@
 @date:2023/9/4 11:16
 @desc: 认证类
 """
-import traceback
 from importlib import import_module
 
 from django.conf import settings
@@ -18,6 +17,7 @@ from rest_framework.authentication import TokenAuthentication
 
 from common.exception.app_exception import AppAuthenticationFailed, AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed, \
     AppApiException
+from common.utils.logger import maxkb_logger
 
 token_cache = cache.caches['default']
 
@@ -88,7 +88,7 @@ class TokenAuth(TokenAuthentication):
                     return handle.handle(request, token, token_details.get_token_details)
             raise AppAuthenticationFailed(1002, _('Authentication information is incorrect! illegal user'))
         except Exception as e:
-            traceback.print_stack()
+            maxkb_logger.error(f'Exception: {e}', exc_info=True)
             if isinstance(e, AppEmbedIdentityFailed) or isinstance(e, AppChatNumOutOfBoundsFailed) or isinstance(e,
                                                                                                                  AppApiException):
                 raise e
diff --git a/apps/common/auth/common.py b/apps/common/auth/common.py
index 40158f7d2..ad8e0e50a 100644
--- a/apps/common/auth/common.py
+++ b/apps/common/auth/common.py
@@ -45,7 +45,7 @@ class ChatAuthentication:
         value = json.dumps(self.to_dict())
         authentication = encrypt(value)
         cache_key = hashlib.sha256(authentication.encode()).hexdigest()
-        authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.value, timeout=60 * 60 * 2)
+        authentication_cache.set(cache_key, value, version=Cache_Version.CHAT.get_version(), timeout=60 * 60 * 2)
         return authentication
 
     @staticmethod
diff --git a/apps/common/config/tokenizer_manage_config.py b/apps/common/config/tokenizer_manage_config.py
index 47a6d61e9..9a3ae73f2 100644
--- a/apps/common/config/tokenizer_manage_config.py
+++ b/apps/common/config/tokenizer_manage_config.py
@@ -7,18 +7,24 @@
 @desc:
 """
 
+import os
+
+
+class MKTokenizer:
+    def __init__(self, tokenizer):
+        self.tokenizer = tokenizer
+
+    def encode(self, text):
+        return self.tokenizer.encode(text).ids
+
 
 class TokenizerManage:
     tokenizer = None
 
     @staticmethod
     def get_tokenizer():
-        from transformers import BertTokenizer
-        if TokenizerManage.tokenizer is None:
-            TokenizerManage.tokenizer = BertTokenizer.from_pretrained(
-                'bert-base-cased',
-                cache_dir="/opt/maxkb-app/model/tokenizer",
-                local_files_only=True,
-                resume_download=False,
-                force_download=False)
-        return TokenizerManage.tokenizer
+        from tokenizers import Tokenizer
+        # 创建Tokenizer
+        model_path = os.path.join("/opt/maxkb-app", "model", "tokenizer", "models--bert-base-cased")
+        with open(f"{model_path}/refs/main", encoding="utf-8") as f: snapshot = f.read()
+        TokenizerManage.tokenizer = Tokenizer.from_file(f"{model_path}/snapshots/{snapshot}/tokenizer.json")
+        return MKTokenizer(TokenizerManage.tokenizer)
diff --git a/apps/common/constants/cache_version.py b/apps/common/constants/cache_version.py
index 2cf889c17..6664acb56 100644
--- a/apps/common/constants/cache_version.py
+++ b/apps/common/constants/cache_version.py
@@ -30,6 +30,8 @@ class Cache_Version(Enum):
     # 对话
     CHAT = "CHAT", lambda key: key
 
+    CHAT_INFO = "CHAT_INFO", lambda key: key
+
     CHAT_VARIABLE = "CHAT_VARIABLE", lambda key: key
 
     # 应用API KEY
diff --git a/apps/common/handle/impl/text/html_split_handle.py b/apps/common/handle/impl/text/html_split_handle.py
index a82cfdaec..6ac8c44d9 100644
--- a/apps/common/handle/impl/text/html_split_handle.py
+++ b/apps/common/handle/impl/text/html_split_handle.py
@@ -77,5 +77,5 @@ class HTMLSplitHandle(BaseSplitHandle):
                 content = buffer.decode(encoding)
             return html2text(content)
         except BaseException as e:
-            traceback.print_exception(e)
+            maxkb_logger.error(f'Exception: {e}', exc_info=True)
             return f'{e}'
diff --git a/apps/common/management/commands/services/command.py b/apps/common/management/commands/services/command.py
index 05095e5cb..5dfb7570d 100644
--- a/apps/common/management/commands/services/command.py
+++ b/apps/common/management/commands/services/command.py
@@ -11,7 +11,6 @@ class Services(TextChoices):
     gunicorn = 'gunicorn', 'gunicorn'
     celery_default = 'celery_default', 'celery_default'
     local_model = 'local_model', 'local_model'
-    scheduler = 'scheduler', 'scheduler'
     web = 'web', 'web'
     celery = 'celery', 'celery'
     celery_model = 'celery_model', 'celery_model'
@@ -25,7 +24,6 @@ class Services(TextChoices):
             cls.gunicorn.value: services.GunicornService,
             cls.celery_default: services.CeleryDefaultService,
             cls.local_model: services.GunicornLocalModelService,
-            cls.scheduler: services.SchedulerService,
         }
         return services_map.get(name)
 
@@ -41,13 +39,10 @@ class Services(TextChoices):
     def task_services(cls):
         return cls.celery_services()
 
-    @classmethod
-    def scheduler_services(cls):
-        return [cls.scheduler]
 
     @classmethod
     def all_services(cls):
-        return cls.web_services() + cls.task_services() + cls.scheduler_services()
+        return cls.web_services() + cls.task_services()
 
     @classmethod
     def export_services_values(cls):
@@ -101,7 +96,7 @@ class BaseActionCommand(BaseCommand):
         )
         parser.add_argument('-d', '--daemon', nargs="?", const=True)
         parser.add_argument('-w', '--worker', type=int, nargs="?",
-                            default=2 if os.cpu_count() > 6 else math.floor(os.cpu_count() / 2))
+                            default=3 if os.cpu_count() > 6 else max(1, math.floor(os.cpu_count() / 2)))
         parser.add_argument('-f', '--force', nargs="?", const=True)
 
     def initial_util(self, *args, **options):
diff --git a/apps/common/management/commands/services/services/gunicorn.py b/apps/common/management/commands/services/services/gunicorn.py
index b5b8673da..4a7233920 100644
--- a/apps/common/management/commands/services/services/gunicorn.py
+++ b/apps/common/management/commands/services/services/gunicorn.py
@@ -18,14 +18,17 @@ class GunicornService(BaseService):
 
         log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
         bind = f'{HTTP_HOST}:{HTTP_PORT}'
+        max_requests = 10240 if int(self.worker) > 1 else 0
         cmd = [
             'gunicorn', 'maxkb.wsgi:application',
             '-b', bind,
             '-k', 'gthread',
             '--threads', '200',
             '-w', str(self.worker),
-            '--max-requests', '10240',
+            '--max-requests', str(max_requests),
             '--max-requests-jitter', '2048',
+            '--timeout', '0',
+            '--graceful-timeout', '0',
             '--access-logformat', log_format,
             '--access-logfile', '/dev/null',
             '--error-logfile', '-'
diff --git a/apps/common/management/commands/services/services/local_model.py b/apps/common/management/commands/services/services/local_model.py
index f523b1b1f..a37f0d16b 100644
--- a/apps/common/management/commands/services/services/local_model.py
+++ b/apps/common/management/commands/services/services/local_model.py
@@ -27,14 +27,17 @@ class GunicornLocalModelService(BaseService):
         log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
         bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
         worker = CONFIG.get("LOCAL_MODEL_HOST_WORKER", 1)
+        max_requests = 10240 if int(worker) > 1 else 0
         cmd = [
             'gunicorn', 'maxkb.wsgi:application',
             '-b', bind,
             '-k', 'gthread',
             '--threads', '200',
             '-w', str(worker),
-            '--max-requests', '10240',
+            '--max-requests', str(max_requests),
             '--max-requests-jitter', '2048',
+            '--timeout', '0',
+            '--graceful-timeout', '0',
             '--access-logformat', log_format,
             '--access-logfile', '/dev/null',
             '--error-logfile', '-'
diff --git a/apps/common/management/commands/services/services/scheduler.py b/apps/common/management/commands/services/services/scheduler.py
index e9a0bd97a..dcef849b8 100644
--- a/apps/common/management/commands/services/services/scheduler.py
+++ b/apps/common/management/commands/services/services/scheduler.py
@@ -18,14 +18,17 @@ class SchedulerService(BaseService):
 
         log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
         bind = f'127.0.0.1:6060'
+        max_requests = 10240 if int(self.worker) > 1 else 0
         cmd = [
             'gunicorn', 'maxkb.wsgi:application',
             '-b', bind,
             '-k', 'gthread',
             '--threads', '200',
             '-w', str(self.worker),
-            '--max-requests', '10240',
+            '--max-requests', str(max_requests),
             '--max-requests-jitter', '2048',
+            '--timeout', '0',
+            '--graceful-timeout', '0',
             '--access-logformat', log_format,
             '--access-logfile', '/dev/null',
             '--error-logfile', '-'
diff --git a/apps/common/mixins/app_model_mixin.py b/apps/common/mixins/app_model_mixin.py
index b17d233bd..86d334c52 100644
--- a/apps/common/mixins/app_model_mixin.py
+++ b/apps/common/mixins/app_model_mixin.py
@@ -10,6 +10,8 @@ from django.db import models
 
 
 class AppModelMixin(models.Model):
+    objects = models.Manager()
+
     create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, db_index=True)
     update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, db_index=True)
 
diff --git a/apps/common/utils/rsa_util.py b/apps/common/utils/rsa_util.py
index 5631be50e..1eba2389a 100644
--- a/apps/common/utils/rsa_util.py
+++ b/apps/common/utils/rsa_util.py
@@ -8,6 +8,7 @@
 """
 import base64
 import threading
+from functools import lru_cache
 
 from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher
 from Crypto.PublicKey import RSA
@@ -70,7 +71,7 @@ def encrypt(msg, public_key: str | None = None):
     """
     if public_key is None:
         public_key = get_key_pair().get('key')
-    cipher = PKCS1_cipher.new(RSA.importKey(public_key))
+    cipher = _get_encrypt_cipher(public_key)
     encrypt_msg = cipher.encrypt(msg.encode("utf-8"))
     return base64.b64encode(encrypt_msg).decode()
 
@@ -84,56 +85,69 @@ def decrypt(msg, pri_key: str | None = None):
     """
     if pri_key is None:
         pri_key = get_key_pair().get('value')
-    cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code))
+    cipher = _get_cipher(pri_key)
     decrypt_data = cipher.decrypt(base64.b64decode(msg), 0)
     return decrypt_data.decode("utf-8")
 
+
+@lru_cache(maxsize=2)
+def _get_encrypt_cipher(public_key: str):
+    """缓存加密 cipher 对象"""
+    return PKCS1_cipher.new(RSA.importKey(extern_key=public_key, passphrase=secret_code))
+
+
 def rsa_long_encrypt(message, public_key: str | None = None, length=200):
     """
     超长文本加密
     :param message: 需要加密的字符串
     :param public_key 公钥
-    :param length: 1024bit的证书用100, 2048bit的证书用 200
+    :param length: 1024bit的证书用100, 2048bit的证书用 200
     :return: 加密后的数据
     """
-    # 读取公钥
     if public_key is None:
         public_key = get_key_pair().get('key')
-    cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key,
-                                            passphrase=secret_code))
分段加密 + + cipher = _get_encrypt_cipher(public_key) + if len(message) <= length: - # 对编码的数据进行加密,并通过base64进行编码 result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) else: rsa_text = [] - # 对编码后的数据进行切片,原因:加密长度不能过长 for i in range(0, len(message), length): cont = message[i:i + length] - # 对切片后的数据进行加密,并新增到text后面 rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) - # 加密完进行拼接 cipher_text = b''.join(rsa_text) - # base64进行编码 result = base64.b64encode(cipher_text) + return result.decode() +@lru_cache(maxsize=2) +def _get_cipher(pri_key: str): + """缓存 cipher 对象,避免重复创建""" + return PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + + def rsa_long_decrypt(message, pri_key: str | None = None, length=256): """ - 超长文本解密,默认不加密 + 超长文本解密,优化内存使用 :param message: 需要解密的数据 :param pri_key: 秘钥 - :param length : 1024bit的证书用128,2048bit证书用256位 + :param length : 1024bit的证书用128,2048bit证书用256位 :return: 解密后的数据 """ if pri_key is None: pri_key = get_key_pair().get('value') - cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + + cipher = _get_cipher(pri_key) base64_de = base64.b64decode(message) - res = [] + + # 使用 bytearray 减少内存分配 + result = bytearray() for i in range(0, len(base64_de), length): - res.append(cipher.decrypt(base64_de[i:i + length], 0)) - return b"".join(res).decode() + result.extend(cipher.decrypt(base64_de[i:i + length], 0)) + + return result.decode() + diff --git a/apps/common/utils/tool_code.py b/apps/common/utils/tool_code.py index f8e819557..902622647 100644 --- a/apps/common/utils/tool_code.py +++ b/apps/common/utils/tool_code.py @@ -1,51 +1,77 @@ # coding=utf-8 import ast +import base64 +import gzip import json import os +import socket import subprocess import sys -from textwrap import dedent -import socket import uuid_utils.compat as uuid +from common.utils.logger import maxkb_logger from django.utils.translation import gettext_lazy as _ - from maxkb.const import BASE_DIR, CONFIG from maxkb.const import PROJECT_DIR +from textwrap import dedent python_directory = sys.executable class ToolExecutor: + def __init__(self, sandbox=False): self.sandbox = sandbox if sandbox: - self.sandbox_path = '/opt/maxkb-app/sandbox' + self.sandbox_path = CONFIG.get("SANDBOX_HOME", '/opt/maxkb-app/sandbox') self.user = 'sandbox' else: self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox') self.user = None - self._createdir() - if self.sandbox: - os.system(f"chown -R {self.user}:root {self.sandbox_path}") self.banned_keywords = CONFIG.get("SANDBOX_PYTHON_BANNED_KEYWORDS", 'nothing_is_banned').split(','); - banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip() + self.sandbox_so_path = f'{self.sandbox_path}/sandbox.so' try: - if banned_hosts: - hostname = socket.gethostname() - local_ip = socket.gethostbyname(hostname) - banned_hosts = f"{banned_hosts},{hostname},{local_ip}" - except Exception: - pass - self.banned_hosts = banned_hosts + self._init_dir() + except Exception as e: + # 本机忽略异常,容器内不忽略 + maxkb_logger.error(f'Exception: {e}', exc_info=True) + if self.sandbox: + raise e - def _createdir(self): - old_mask = os.umask(0o077) + def _init_dir(self): try: - os.makedirs(self.sandbox_path, 0o700, exist_ok=True) - os.makedirs(os.path.join(self.sandbox_path, 'execute'), 0o700, exist_ok=True) - os.makedirs(os.path.join(self.sandbox_path, 'result'), 0o700, exist_ok=True) - finally: - os.umask(old_mask) + # 只初始化一次 + fd = os.open(os.path.join(PROJECT_DIR, 'tmp', 'tool_executor_init_dir.lock'), + os.O_CREAT | os.O_EXCL | os.O_WRONLY) + os.close(fd) + 
except FileExistsError: + # 文件已存在 → 已初始化过 + return + maxkb_logger.debug("init dir") + if self.sandbox: + try: + os.system("chmod -R g-rwx /dev/shm /dev/mqueue") + os.system("chmod o-rwx /run/postgresql") + except Exception as e: + maxkb_logger.warning(f'Exception: {e}', exc_info=True) + pass + if CONFIG.get("SANDBOX_TMP_DIR_ENABLED", '0') == "1": + tmp_dir_path = os.path.join(self.sandbox_path, 'tmp') + os.makedirs(tmp_dir_path, 0o700, exist_ok=True) + os.system(f"chown -R {self.user}:root {tmp_dir_path}") + if os.path.exists(self.sandbox_so_path): + os.chmod(self.sandbox_so_path, 0o440) + # 初始化host黑名单 + banned_hosts_file_path = f'{self.sandbox_path}/.SANDBOX_BANNED_HOSTS' + if os.path.exists(banned_hosts_file_path): + os.remove(banned_hosts_file_path) + banned_hosts = CONFIG.get("SANDBOX_PYTHON_BANNED_HOSTS", '').strip() + if banned_hosts: + hostname = socket.gethostname() + local_ip = socket.gethostbyname(hostname) + banned_hosts = f"{banned_hosts},{hostname},{local_ip}" + with open(banned_hosts_file_path, "w") as f: + f.write(banned_hosts) + os.chmod(banned_hosts_file_path, 0o440) def exec_code(self, code_str, keywords, function_name=None): self.validate_banned_keywords(code_str) @@ -57,9 +83,7 @@ class ToolExecutor: python_paths = CONFIG.get_sandbox_python_package_paths().split(',') _exec_code = f""" try: - import os - import sys - import json + import sys, json, base64, builtins path_to_exclude = ['/opt/py3/lib/python3.11/site-packages', '/opt/maxkb-app/apps'] sys.path = [p for p in sys.path if p not in path_to_exclude] sys.path += {python_paths} @@ -71,21 +95,21 @@ try: for local in locals_v: globals_v[local] = locals_v[local] exec_result=f(**keywords) - with open({result_path!a}, 'w') as file: - file.write(json.dumps({success}, default=str)) + builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({success}, default=str).encode()).decode()) except Exception as e: - with open({result_path!a}, 'w') as file: - file.write(json.dumps({err})) + builtins.print("\\n{_id}:"+base64.b64encode(json.dumps({err}, default=str).encode()).decode()) """ if self.sandbox: - subprocess_result = self._exec_sandbox(_exec_code, _id) + subprocess_result = self._exec_sandbox(_exec_code) else: subprocess_result = self._exec(_exec_code) if subprocess_result.returncode == 1: raise Exception(subprocess_result.stderr) - with open(result_path, 'r') as file: - result = json.loads(file.read()) - os.remove(result_path) + lines = subprocess_result.stdout.splitlines() + result_line = [line for line in lines if line.startswith(_id)] + if not result_line: + raise Exception("No result found.") + result = json.loads(base64.b64decode(result_line[-1].split(":", 1)[1]).decode()) if result.get('code') == 200: return result.get('data') raise Exception(result.get('msg')) @@ -173,52 +197,45 @@ exec({dedent(code)!a}) """ def get_tool_mcp_config(self, code, params): - code = self.generate_mcp_server_code(code, params) - - _id = uuid.uuid7() - code_path = f'{self.sandbox_path}/execute/{_id}.py' - with open(code_path, 'w') as f: - f.write(code) + _code = self.generate_mcp_server_code(code, params) + maxkb_logger.debug(f"Python code of mcp tool: {_code}") + compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode() if self.sandbox: - os.system(f"chown {self.user}:root {code_path}") - tool_config = { 'command': 'su', 'args': [ '-s', sys.executable, - '-c', f"exec(open('{code_path}', 'r').read())", + '-c', + f'import base64,gzip; 
exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())', self.user, ], 'cwd': self.sandbox_path, 'env': { - 'LD_PRELOAD': '/opt/maxkb-app/sandbox/sandbox.so', - 'SANDBOX_BANNED_HOSTS': self.banned_hosts, + 'LD_PRELOAD': self.sandbox_so_path, }, 'transport': 'stdio', } else: tool_config = { 'command': sys.executable, - 'args': [code_path], + 'args': ['-c', f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())'], 'transport': 'stdio', } - return _id, tool_config + return tool_config - def _exec_sandbox(self, _code, _id): - exec_python_file = f'{self.sandbox_path}/execute/{_id}.py' - with open(exec_python_file, 'w') as file: - file.write(_code) - os.system(f"chown {self.user}:root {exec_python_file}") + def _exec_sandbox(self, _code): kwargs = {'cwd': BASE_DIR} kwargs['env'] = { - 'LD_PRELOAD': '/opt/maxkb-app/sandbox/sandbox.so', - 'SANDBOX_BANNED_HOSTS': self.banned_hosts, + 'LD_PRELOAD': self.sandbox_so_path, } + maxkb_logger.debug(f"Sandbox execute code: {_code}") + compressed_and_base64_encoded_code_str = base64.b64encode(gzip.compress(_code.encode())).decode() subprocess_result = subprocess.run( - ['su', '-s', python_directory, '-c', "exec(open('" + exec_python_file + "').read())", self.user], + ['su', '-s', python_directory, '-c', + f'import base64,gzip; exec(gzip.decompress(base64.b64decode(\'{compressed_and_base64_encoded_code_str}\')).decode())', + self.user], text=True, capture_output=True, **kwargs) - os.remove(exec_python_file) return subprocess_result def validate_banned_keywords(self, code_str): diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index 8f568c110..7e9fa4237 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -816,7 +816,7 @@ class DocumentSerializers(serializers.Serializer): @post(post_function=post_embedding) @transaction.atomic - def save(self, instance: Dict, with_valid=False, **kwargs): + def save(self, instance: Dict, with_valid=True, **kwargs): if with_valid: DocumentInstanceSerializer(data=instance).is_valid(raise_exception=True) self.is_valid(raise_exception=True) diff --git a/apps/knowledge/task/sync.py b/apps/knowledge/task/sync.py index 20d92386e..a113f7b48 100644 --- a/apps/knowledge/task/sync.py +++ b/apps/knowledge/task/sync.py @@ -7,7 +7,6 @@ @desc: """ -import logging import traceback from typing import List diff --git a/apps/knowledge/views/document.py b/apps/knowledge/views/document.py index 88f529f1a..1b3301d69 100644 --- a/apps/knowledge/views/document.py +++ b/apps/knowledge/views/document.py @@ -542,9 +542,9 @@ class DocumentView(APIView): @extend_schema( methods=['PUT'], - summary=_('Batch generate related documents'), - description=_('Batch generate related documents'), - operation_id=_('Batch generate related documents'), # type: ignore + summary=_('Batch generate related problems'), + description=_('Batch generate related problems'), + operation_id=_('Batch generate related problems'), # type: ignore request=BatchGenerateRelatedAPI.get_request(), parameters=BatchGenerateRelatedAPI.get_parameters(), responses=BatchGenerateRelatedAPI.get_response(), @@ -560,7 +560,7 @@ class DocumentView(APIView): [PermissionConstants.KNOWLEDGE.get_workspace_knowledge_permission()], CompareConstants.AND), ) @log( - menu='document', operate="Batch generate related documents", + menu='document', operate="Batch generate related problems", get_operation_object=lambda r,
keywords: get_knowledge_document_operation_object( get_knowledge_operation_object(keywords.get('knowledge_id')), get_document_operation_object_batch(r.data.get('document_id_list')) diff --git a/apps/local_model/serializers/model_apply_serializers.py b/apps/local_model/serializers/model_apply_serializers.py index cf57c2fef..5ecb2260c 100644 --- a/apps/local_model/serializers/model_apply_serializers.py +++ b/apps/local_model/serializers/model_apply_serializers.py @@ -74,7 +74,6 @@ class ModelManage: def get_local_model(model, **kwargs): - # system_setting = QuerySet(SystemSetting).filter(type=1).first() return LocalModelProvider().get_model(model.model_type, model.model_name, json.loads( rsa_long_decrypt(model.credential)), @@ -111,6 +110,21 @@ class CompressDocuments(serializers.Serializer): query = serializers.CharField(required=True, label=_('query')) +class ValidateModelSerializers(serializers.Serializer): + model_name = serializers.CharField(required=True, label=_('model_name')) + + model_type = serializers.CharField(required=True, label=_('model_type')) + + model_credential = serializers.DictField(required=True, label="credential") + + def validate_model(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + LocalModelProvider().is_valid_credential(self.data.get('model_type'), self.data.get('model_name'), + self.data.get('model_credential'), model_params={}, + raise_exception=True) + + class ModelApplySerializers(serializers.Serializer): model_id = serializers.UUIDField(required=True, label=_('model id')) @@ -138,3 +152,9 @@ class ModelApplySerializers(serializers.Serializer): return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in instance.get('documents')], instance.get('query'))] + + def unload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ModelManage.delete_key(self.data.get('model_id')) + return True diff --git a/apps/local_model/urls.py b/apps/local_model/urls.py index cc47946e0..a9c060254 100644 --- a/apps/local_model/urls.py +++ b/apps/local_model/urls.py @@ -7,7 +7,9 @@ from . 
import views app_name = "local_model" # @formatter:off urlpatterns = [ + path('model/validate', views.LocalModelApply.Validate.as_view()), path('model/<str:model_id>/embed_documents', views.LocalModelApply.EmbedDocuments.as_view()), path('model/<str:model_id>/embed_query', views.LocalModelApply.EmbedQuery.as_view()), path('model/<str:model_id>/compress_documents', views.LocalModelApply.CompressDocuments.as_view()), + path('model/<str:model_id>/unload', views.LocalModelApply.Unload.as_view()), ] diff --git a/apps/local_model/views/model_apply.py b/apps/local_model/views/model_apply.py index 218d4f091..98c07dd74 100644 --- a/apps/local_model/views/model_apply.py +++ b/apps/local_model/views/model_apply.py @@ -11,7 +11,7 @@ from urllib.request import Request from rest_framework.views import APIView from common.result import result -from local_model.serializers.model_apply_serializers import ModelApplySerializers +from local_model.serializers.model_apply_serializers import ModelApplySerializers, ValidateModelSerializers class LocalModelApply(APIView): @@ -32,3 +32,12 @@ class LocalModelApply(APIView): def post(self, request: Request, model_id): return result.success( ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) + + class Unload(APIView): + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).unload()) + + class Validate(APIView): + def post(self, request: Request): + return result.success(ValidateModelSerializers(data=request.data).validate_model()) diff --git a/apps/locales/en_US/LC_MESSAGES/django.po b/apps/locales/en_US/LC_MESSAGES/django.po index a24494790..44808ce4d 100644 --- a/apps/locales/en_US/LC_MESSAGES/django.po +++ b/apps/locales/en_US/LC_MESSAGES/django.po @@ -2945,7 +2945,7 @@ msgstr "" #: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467 #: apps/knowledge/views/document.py:468 -msgid "Batch generate related documents" +msgid "Batch generate related problems" msgstr "" #: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497 @@ -5126,7 +5126,7 @@ msgstr "" #: apps/resource_manage/views/document.py:261 #: apps/resource_manage/views/document.py:262 #: apps/resource_manage/views/document.py:263 -msgid "Batch generate related system document" +msgid "Batch generate related system problems" msgstr "" #: apps/resource_manage/views/document.py:290 @@ -5946,7 +5946,7 @@ msgstr "" #: apps/shared/views/shared_document.py:262 #: apps/shared/views/shared_document.py:263 #: apps/shared/views/shared_document.py:264 -msgid "Batch generate related shared documents" +msgid "Batch generate related shared problems" msgstr "" #: apps/shared/views/shared_document.py:291 @@ -6671,7 +6671,7 @@ msgid "Phone" msgstr "" #: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68 -msgid "Username must be 4-20 characters long" +msgid "Username must be 4-64 characters long" msgstr "" #: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298 @@ -8766,4 +8766,10 @@ msgid "Non-existent id" msgstr "" msgid "No permission for the target folder" +msgstr "" + +msgid "Application token usage statistics" +msgstr "" + +msgid "Application top question statistics" msgstr "" \ No newline at end of file diff --git a/apps/locales/zh_CN/LC_MESSAGES/django.po b/apps/locales/zh_CN/LC_MESSAGES/django.po index 793bef404..a2c9298a8 100644 --- a/apps/locales/zh_CN/LC_MESSAGES/django.po +++ b/apps/locales/zh_CN/LC_MESSAGES/django.po @@ -2952,8 +2952,8 @@ msgstr "批量刷新文档向量库" #:
apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467 #: apps/knowledge/views/document.py:468 -msgid "Batch generate related documents" -msgstr "批量生成相关文档" +msgid "Batch generate related problems" +msgstr "批量生成相关问题" #: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497 #: apps/knowledge/views/document.py:498 @@ -5251,8 +5251,8 @@ msgstr "批量刷新文档向量库" #: apps/resource_manage/views/document.py:261 #: apps/resource_manage/views/document.py:262 #: apps/resource_manage/views/document.py:263 -msgid "Batch generate related system document" -msgstr "批量生成相关文档" +msgid "Batch generate related system problems" +msgstr "批量生成相关问题" #: apps/resource_manage/views/document.py:290 #: apps/resource_manage/views/document.py:291 @@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文档向量库" #: apps/shared/views/shared_document.py:262 #: apps/shared/views/shared_document.py:263 #: apps/shared/views/shared_document.py:264 -msgid "Batch generate related shared documents" -msgstr "批量生成相关共享文档" +msgid "Batch generate related shared problems" +msgstr "批量生成相关问题" #: apps/shared/views/shared_document.py:291 #: apps/shared/views/shared_document.py:292 @@ -6796,8 +6796,8 @@ msgid "Phone" msgstr "手机" #: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68 -msgid "Username must be 4-20 characters long" -msgstr "用户名必须为4-20个字符" +msgid "Username must be 4-64 characters long" +msgstr "用户名必须为4-64个字符" #: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298 #: apps/xpack/serializers/chat_user.py:81 @@ -8894,3 +8894,8 @@ msgstr "不存在的ID" msgid "No permission for the target folder" msgstr "没有目标文件夹的权限" +msgid "Application token usage statistics" +msgstr "应用令牌使用统计" + +msgid "Application top question statistics" +msgstr "应用提问次数统计" diff --git a/apps/locales/zh_Hant/LC_MESSAGES/django.po b/apps/locales/zh_Hant/LC_MESSAGES/django.po index e54373750..aa21afb30 100644 --- a/apps/locales/zh_Hant/LC_MESSAGES/django.po +++ b/apps/locales/zh_Hant/LC_MESSAGES/django.po @@ -2952,8 +2952,8 @@ msgstr "批量刷新文檔向量庫" #: apps/knowledge/views/document.py:466 apps/knowledge/views/document.py:467 #: apps/knowledge/views/document.py:468 -msgid "Batch generate related documents" -msgstr "批量生成相關文檔" +msgid "Batch generate related problems" +msgstr "批量生成相關問題" #: apps/knowledge/views/document.py:496 apps/knowledge/views/document.py:497 #: apps/knowledge/views/document.py:498 @@ -5251,8 +5251,8 @@ msgstr "批量刷新文檔向量庫" #: apps/resource_manage/views/document.py:261 #: apps/resource_manage/views/document.py:262 #: apps/resource_manage/views/document.py:263 -msgid "Batch generate related system document" -msgstr "批量生成相關文檔" +msgid "Batch generate related system problems" +msgstr "批量生成相關問題" #: apps/resource_manage/views/document.py:290 #: apps/resource_manage/views/document.py:291 @@ -6071,8 +6071,8 @@ msgstr "批量刷新共享文檔向量庫" #: apps/shared/views/shared_document.py:262 #: apps/shared/views/shared_document.py:263 #: apps/shared/views/shared_document.py:264 -msgid "Batch generate related shared documents" -msgstr "批量生成相關共享文檔" +msgid "Batch generate related shared problems" +msgstr "批量生成相關問題" #: apps/shared/views/shared_document.py:291 #: apps/shared/views/shared_document.py:292 @@ -6796,8 +6796,8 @@ msgid "Phone" msgstr "手機" #: apps/users/serializers/user.py:120 apps/xpack/serializers/chat_user.py:68 -msgid "Username must be 4-20 characters long" -msgstr "用戶名必須為4-20個字符" +msgid "Username must be 4-64 characters long" +msgstr "用戶名必須為4-64個字符" #: apps/users/serializers/user.py:133 apps/users/serializers/user.py:298 #: 
apps/xpack/serializers/chat_user.py:81 @@ -8894,3 +8894,8 @@ msgstr "不存在的ID" msgid "No permission for the target folder" msgstr "沒有目標資料夾的權限" +msgid "Application token usage statistics" +msgstr "應用令牌使用統計" + +msgid "Application top question statistics" +msgstr "應用提問次數統計" \ No newline at end of file diff --git a/apps/maxkb/wsgi/web.py b/apps/maxkb/wsgi/web.py index e95125e42..d1fa687d8 100644 --- a/apps/maxkb/wsgi/web.py +++ b/apps/maxkb/wsgi/web.py @@ -6,16 +6,36 @@ @date:2025/11/5 15:14 @desc: """ +import builtins import os +import sys from django.core.wsgi import get_wsgi_application + +class TorchBlocker: + def __init__(self): + self.original_import = builtins.__import__ + + def __call__(self, name, *args, **kwargs): + if len([True for i in + ['torch'] + if + i in name.lower()]) > 0: + print(f"Disable package is being imported: 【{name}】", file=sys.stderr) + pass + else: + return self.original_import(name, *args, **kwargs) + + +# 安装导入拦截器 +builtins.__import__ = TorchBlocker() + os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'maxkb.settings') - +os.environ['TIKTOKEN_CACHE_DIR'] = '/opt/maxkb-app/model/tokenizer/openai-tiktoken-cl100k-base' application = get_wsgi_application() - def post_handler(): from common.database_model_manage.database_model_manage import DatabaseModelManage from common import event @@ -24,14 +44,5 @@ def post_handler(): DatabaseModelManage.init() -def post_scheduler_handler(): - from common import job - - job.run() - # 启动后处理函数 post_handler() - -# 仅在scheduler中启动定时任务,dev local_model celery 不需要 -if os.environ.get('ENABLE_SCHEDULER') == '1': - post_scheduler_handler() \ No newline at end of file diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py index 93374e47d..64f0aea5b 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/asr_stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict, Any from common import forms @@ -7,7 +6,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext as _ - +from common.utils.logger import maxkb_logger class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential): api_url = forms.TextInputField(_('API URL'), required=True) @@ -41,7 +40,7 @@ class AliyunBaiLianAsrSTTModelCredential(BaseForm, BaseModelCredential): try: model = provider.get_model(model_type, model_name, model_credential) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py index 0a3c39107..28127c055 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/10/16 17:01 @desc: """ -import traceback from typing import Dict, Any from django.utils.translation import gettext as _ @@ -16,6 +15,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from 
models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding +from common.utils.logger import maxkb_logger class BaiLianEmbeddingModelParams(BaseForm): dimensions = forms.SingleSelect( @@ -69,7 +69,7 @@ class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential): model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential) model.embed_query(_("Hello")) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py index 444b1ff2f..24657f6f4 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/image.py @@ -6,11 +6,9 @@ @date:2024/7/11 18:41 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext -from langchain_core.messages import HumanMessage from common import forms from common.exception.app_exception import AppApiException @@ -69,7 +67,7 @@ class QwenVLModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth(model_credential.get('api_key')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py index 81c4e2286..7f38f49c4 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/itv.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import Dict, Any from django.utils.translation import gettext_lazy as _, gettext @@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from common.forms.switch_field import SwitchField from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class QwenModelParams(BaseForm): """ @@ -85,7 +84,7 @@ class ImageToVideoModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py index 53d9ebda9..9511e2ef0 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/llm.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -9,7 +8,7 @@ from common import forms from common.exception.app_exception import AppApiException from 
common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class BaiLianLLMModelParams(BaseForm): temperature = forms.SliderField( @@ -76,7 +75,7 @@ class BaiLianLLMModelCredential(BaseForm, BaseModelCredential): else: model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py index 0b704fbc4..6b1e2f894 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/omni_stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict, Any from common import forms @@ -7,6 +6,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext as _ +from common.utils.logger import maxkb_logger class AliyunBaiLianOmiSTTModelParams(BaseForm): CueWord = forms.TextInputField( @@ -49,7 +49,7 @@ class AliyunBaiLianOmiSTTModelCredential(BaseForm, BaseModelCredential): try: model = provider.get_model(model_type, model_name, model_credential) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py index f6131492a..a43b4007b 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import Dict, Any from django.utils.translation import gettext as _ @@ -10,7 +9,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker - +from common.utils.logger import maxkb_logger class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): """ @@ -60,7 +59,7 @@ class AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential) model.compress_documents([Document(page_content=_('Hello'))], _('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py index a6ee93912..dd2f56c23 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import 
Dict, Any from django.utils.translation import gettext as _ @@ -9,7 +8,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AliyunBaiLianSTTModelParams(BaseForm): sample_rate = forms.SliderField( @@ -68,7 +67,7 @@ class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py index 5731b5a6d..82f1d7185 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tti.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import Dict, Any from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class QwenModelParams(BaseForm): """ @@ -110,7 +109,7 @@ class QwenTextToImageModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py index 681b670fc..089d25900 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from django.utils.translation import gettext_lazy as _, gettext from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AliyunBaiLianTTSModelGeneralParams(BaseForm): """ @@ -103,7 +102,7 @@ class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py index 14709bb90..b78f86ab9 100644 --- 
a/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/credential/ttv.py @@ -1,6 +1,5 @@ # coding=utf-8 -import traceback from typing import Dict, Any from django.utils.translation import gettext_lazy as _, gettext @@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, PasswordInputField, SingleSelect, SliderField, TooltipLabel from common.forms.switch_field import SwitchField from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class QwenModelParams(BaseForm): """ @@ -87,7 +86,7 @@ class TextToVideoModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py index ece41f3dd..f48c0adf2 100644 --- a/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py +++ b/apps/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py @@ -54,7 +54,6 @@ class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): 'callback': None, **self.params } - print(recognition_params) recognition = Recognition(**recognition_params) diff --git a/apps/models_provider/impl/anthropic_model_provider/credential/image.py b/apps/models_provider/impl/anthropic_model_provider/credential/image.py index ff824347e..9d0d01f00 100644 --- a/apps/models_provider/impl/anthropic_model_provider/credential/image.py +++ b/apps/models_provider/impl/anthropic_model_provider/credential/image.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -54,7 +53,7 @@ class AnthropicImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/anthropic_model_provider/credential/llm.py b/apps/models_provider/impl/anthropic_model_provider/credential/llm.py index 38e0b3cf5..c90e8ff67 100644 --- a/apps/models_provider/impl/anthropic_model_provider/credential/llm.py +++ b/apps/models_provider/impl/anthropic_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:32 @desc: """ -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext_lazy as _, gettext - +from common.utils.logger import maxkb_logger class AnthropicLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -56,7 +55,7 @@ class AnthropicLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', 
exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py index 0bb5e52af..8c73af139 100644 --- a/apps/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py @@ -1,4 +1,3 @@ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel - +from common.utils.logger import maxkb_logger class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): @@ -35,7 +34,7 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): except AppApiException: raise except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if raise_exception: raise AppApiException(ValidCode.valid_error.value, _('Verification failed, please check whether the parameters are correct: {error}').format( diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index 358e66309..32527dac0 100644 --- a/apps/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -1,4 +1,3 @@ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import ValidCode, BaseModelCredential - +from common.utils.logger import maxkb_logger class BedrockLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -54,7 +53,7 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential): except AppApiException: raise except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext( diff --git a/apps/models_provider/impl/azure_model_provider/credential/embedding.py b/apps/models_provider/impl/azure_model_provider/credential/embedding.py index 329b7acbd..c37f750b1 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/azure_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/11 17:08 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential): @@ -36,7 +35,7 @@ class AzureOpenAIEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: 
{e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/credential/image.py b/apps/models_provider/impl/azure_model_provider/credential/image.py index b5d07d667..eefd0ea98 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/image.py +++ b/apps/models_provider/impl/azure_model_provider/credential/image.py @@ -1,7 +1,6 @@ # coding=utf-8 import base64 import os -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -57,7 +56,7 @@ class AzureOpenAIImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/credential/llm.py b/apps/models_provider/impl/azure_model_provider/credential/llm.py index 1e1967f6f..f9f3ad86b 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/models_provider/impl/azure_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 17:08 @desc: """ -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext_lazy as _, gettext - +from common.utils.logger import maxkb_logger class AzureLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -68,7 +67,7 @@ class AzureLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException) or isinstance(e, BadRequestError): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/credential/stt.py b/apps/models_provider/impl/azure_model_provider/credential/stt.py index cd115473f..ac82e00d4 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/stt.py +++ b/apps/models_provider/impl/azure_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential): api_version = forms.TextInputField("API Version", required=True) @@ -32,7 +31,7 @@ class AzureOpenAISTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/credential/tti.py b/apps/models_provider/impl/azure_model_provider/credential/tti.py index a4eef6b61..c370eaa4e 100644 --- 
a/apps/models_provider/impl/azure_model_provider/credential/tti.py +++ b/apps/models_provider/impl/azure_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AzureOpenAITTIModelParams(BaseForm): size = forms.SingleSelect( @@ -67,7 +66,7 @@ class AzureOpenAITextToImageModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/credential/tts.py b/apps/models_provider/impl/azure_model_provider/credential/tts.py index 7c575b900..c5e6318cf 100644 --- a/apps/models_provider/impl/azure_model_provider/credential/tts.py +++ b/apps/models_provider/impl/azure_model_provider/credential/tts.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class AzureOpenAITTSModelGeneralParams(BaseForm): # alloy, echo, fable, onyx, nova, shimmer @@ -50,7 +49,7 @@ class AzureOpenAITTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py b/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py index 4c0b546ff..36a12553a 100644 --- a/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py +++ b/apps/models_provider/impl/azure_model_provider/model/azure_chat_model.py @@ -7,9 +7,11 @@ @desc: """ -from typing import List, Dict +from typing import List, Dict, Optional, Any +from langchain_core.language_models import LanguageModelInput from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.runnables import RunnableConfig from langchain_openai import AzureChatOpenAI from common.config.tokenizer_manage_config import TokenizerManage @@ -36,16 +38,42 @@ class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI): streaming=True, ) + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.__dict__.get('_last_generation_info') + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: try: - return super().get_num_tokens_from_messages(messages) + return self.get_last_generation_info().get('input_tokens', 0) except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) def get_num_tokens(self, text: str) -> int: try: - return 
super().get_num_tokens(text) + return self.get_last_generation_info().get('output_tokens', 0) except Exception as e: tokenizer = TokenizerManage.get_tokenizer() return len(tokenizer.encode(text)) + + def invoke( + self, + input: LanguageModelInput, + config: Optional[RunnableConfig] = None, + *, + stop: Optional[list[str]] = None, + **kwargs: Any, + ) -> BaseMessage: + message = super().invoke(input, config, stop=stop, **kwargs) + if isinstance(message.content, str): + return message + elif isinstance(message.content, list): + # 构造新的响应消息返回 + content = message.content + normalized_parts = [] + for item in content: + if isinstance(item, dict): + if item.get('type') == 'text': + normalized_parts.append(item.get('text', '')) + message.content = ''.join(normalized_parts) + self.__dict__.setdefault('_last_generation_info', {}).update(message.usage_metadata) + return message diff --git a/apps/models_provider/impl/base_chat_open_ai.py b/apps/models_provider/impl/base_chat_open_ai.py index 6c6698d12..2d4119f55 100644 --- a/apps/models_provider/impl/base_chat_open_ai.py +++ b/apps/models_provider/impl/base_chat_open_ai.py @@ -1,4 +1,5 @@ # coding=utf-8 +import base64 from concurrent.futures import ThreadPoolExecutor from requests.exceptions import ConnectTimeout, ReadTimeout from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping @@ -15,7 +16,7 @@ from langchain_openai import ChatOpenAI from langchain_openai.chat_models.base import _create_usage_metadata from common.config.tokenizer_manage_config import TokenizerManage - +from common.utils.logger import maxkb_logger def custom_get_token_ids(text: str): tokenizer = TokenizerManage.get_tokenizer() @@ -102,13 +103,13 @@ class BaseChatOpenAI(ChatOpenAI): future = executor.submit(super().get_num_tokens_from_messages, messages, tools) try: response = future.result() - print("请求成功(未超时)") + maxkb_logger.info("请求成功(未超时)") return response except Exception as e: if isinstance(e, ReadTimeout): raise # 继续抛出 else: - print("except:", e) + maxkb_logger.error("except:", e) tokenizer = TokenizerManage.get_tokenizer() return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) @@ -211,3 +212,20 @@ class BaseChatOpenAI(ChatOpenAI): self.usage_metadata = chat_result.response_metadata[ 'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata return chat_result + + + def upload_file_and_get_url(self, file_stream, file_name): + """上传文件并获取文件URL""" + base64_video = base64.b64encode(file_stream).decode("utf-8") + video_format = get_video_format(file_name) + return f'data:{video_format};base64,{base64_video}' + +def get_video_format(file_name): + extension = file_name.split('.')[-1].lower() + format_map = { + 'mp4': 'video/mp4', + 'avi': 'video/avi', + 'mov': 'video/mov', + 'wmv': 'video/x-ms-wmv' + } + return format_map.get(extension, 'video/mp4') \ No newline at end of file diff --git a/apps/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/models_provider/impl/deepseek_model_provider/credential/llm.py index 7943c4077..d06997dba 100644 --- a/apps/models_provider/impl/deepseek_model_provider/credential/llm.py +++ b/apps/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 17:51 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import 
BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class DeepSeekLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -56,7 +55,7 @@ class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/gemini_model_provider/credential/embedding.py b/apps/models_provider/impl/gemini_model_provider/credential/embedding.py index 5d88b9fc1..126398495 100644 --- a/apps/models_provider/impl/gemini_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/gemini_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class GeminiEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -35,7 +34,7 @@ class GeminiEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/gemini_model_provider/credential/image.py b/apps/models_provider/impl/gemini_model_provider/credential/image.py index 026ed2e80..4539957b6 100644 --- a/apps/models_provider/impl/gemini_model_provider/credential/image.py +++ b/apps/models_provider/impl/gemini_model_provider/credential/image.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -53,7 +52,7 @@ class GeminiImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/models_provider/impl/gemini_model_provider/credential/llm.py index f369eb309..d1eb5b876 100644 --- a/apps/models_provider/impl/gemini_model_provider/credential/llm.py +++ b/apps/models_provider/impl/gemini_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 17:57 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class GeminiLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -56,7 +55,7 @@ 
class GeminiLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/gemini_model_provider/credential/stt.py b/apps/models_provider/impl/gemini_model_provider/credential/stt.py index 475cf4c2e..470185fa2 100644 --- a/apps/models_provider/impl/gemini_model_provider/credential/stt.py +++ b/apps/models_provider/impl/gemini_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class GeminiSTTModelCredential(BaseForm, BaseModelCredential): api_key = forms.PasswordInputField('API Key', required=True) @@ -30,7 +29,7 @@ class GeminiSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/gemini_model_provider/model/image.py b/apps/models_provider/impl/gemini_model_provider/model/image.py index 1f4e97a18..e4d605c1c 100644 --- a/apps/models_provider/impl/gemini_model_provider/model/image.py +++ b/apps/models_provider/impl/gemini_model_provider/model/image.py @@ -23,6 +23,5 @@ class GeminiImage(MaxKBBaseModel, ChatGoogleGenerativeAI): return GeminiImage( model=model_name, google_api_key=model_credential.get('api_key'), - streaming=True, **optional_params, ) diff --git a/apps/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/models_provider/impl/kimi_model_provider/credential/llm.py index 7c2e4b174..d451e39f3 100644 --- a/apps/models_provider/impl/kimi_model_provider/credential/llm.py +++ b/apps/models_provider/impl/kimi_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:06 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class KimiLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -56,7 +55,7 @@ class KimiLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py new file mode 100644 index 000000000..29828bb74 --- /dev/null +++ 
b/apps/models_provider/impl/local_model_provider/credential/embedding/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:02 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding.py b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py similarity index 91% rename from apps/models_provider/impl/local_model_provider/credential/embedding.py rename to apps/models_provider/impl/local_model_provider/credential/embedding/model.py index 9d656ad98..402c48c12 100644 --- a/apps/models_provider/impl/local_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/model.py @@ -1,12 +1,11 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: embedding.py - @date:2024/7/11 11:06 + @Author:虎虎 + @file: model.py.py + @date:2025/11/7 14:02 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.local_model_provider.model.embedding import LocalEmbedding - +from common.utils.logger import maxkb_logger class LocalEmbeddingCredential(BaseForm, BaseModelCredential): @@ -35,7 +34,7 @@ class LocalEmbeddingCredential(BaseForm, BaseModelCredential): model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) model.embed_query(gettext('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/local_model_provider/credential/embedding/web.py b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py new file mode 100644 index 000000000..4695d141c --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/embedding/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:03 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model catalog'), required=True) diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py 
b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py new file mode 100644 index 000000000..f9ec12bc5 --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: __init__.py.py + @date:2025/11/7 14:22 + @desc: +""" +import os + +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + from .model import * +else: + from .web import * diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker.py b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py similarity index 91% rename from apps/models_provider/impl/local_model_provider/credential/reranker.py rename to apps/models_provider/impl/local_model_provider/credential/reranker/model.py index 46f8ebca2..3c6fa4e32 100644 --- a/apps/models_provider/impl/local_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/model.py @@ -1,12 +1,11 @@ # coding=utf-8 """ @project: MaxKB - @Author:虎 - @file: reranker.py - @date:2024/9/3 14:33 + @Author:虎虎 + @file: model.py + @date:2025/11/7 14:23 @desc: """ -import traceback from typing import Dict from langchain_core.documents import Document @@ -17,7 +16,7 @@ from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.local_model_provider.model.reranker import LocalReranker from django.utils.translation import gettext_lazy as _, gettext - +from common.utils.logger import maxkb_logger class LocalRerankerCredential(BaseForm, BaseModelCredential): @@ -36,7 +35,7 @@ class LocalRerankerCredential(BaseForm, BaseModelCredential): model: LocalReranker = provider.get_model(model_type, model_name, model_credential) model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/local_model_provider/credential/reranker/web.py b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py new file mode 100644 index 000000000..bc86982bf --- /dev/null +++ b/apps/models_provider/impl/local_model_provider/credential/reranker/web.py @@ -0,0 +1,37 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎虎 + @file: web.py + @date:2025/11/7 14:23 + @desc: +""" +from typing import Dict + +import requests +from django.utils.translation import gettext_lazy as _ + +from common import forms +from common.forms import BaseForm +from maxkb.const import CONFIG +from models_provider.base_model_provider import BaseModelCredential + + +class LocalRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + prefix = CONFIG.get_admin_path() + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}{prefix}/api/model/validate', + json={'model_name': model_name, 'model_type': model_type, 'model_credential': model_credential}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField(_('Model 
catalog'), required=True) diff --git a/apps/models_provider/impl/local_model_provider/model/reranker/web.py b/apps/models_provider/impl/local_model_provider/model/reranker/web.py index 45ab6978a..8be3850be 100644 --- a/apps/models_provider/impl/local_model_provider/model/reranker/web.py +++ b/apps/models_provider/impl/local_model_provider/model/reranker/web.py @@ -32,7 +32,6 @@ class LocalReranker(MaxKBBaseModel, BaseModel, BaseDocumentCompressor): def __init__(self, **kwargs): super().__init__(**kwargs) - print('ssss', kwargs.get('model_id', None)) self.model_id = kwargs.get('model_id', None) def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ diff --git a/apps/models_provider/impl/openai_model_provider/credential/embedding.py b/apps/models_provider/impl/openai_model_provider/credential/embedding.py index b86285a4e..4be5f32a8 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/openai_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,6 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +from common.utils.logger import maxkb_logger class OpenAIEmbeddingModelParams(BaseForm): dimensions = forms.SingleSelect( @@ -53,7 +53,7 @@ class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/openai_model_provider/credential/image.py b/apps/models_provider/impl/openai_model_provider/credential/image.py index 3f9a6cc3a..071a8335b 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/image.py +++ b/apps/models_provider/impl/openai_model_provider/credential/image.py @@ -1,7 +1,6 @@ # coding=utf-8 import base64 import os -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -56,7 +55,7 @@ class OpenAIImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/openai_model_provider/credential/llm.py b/apps/models_provider/impl/openai_model_provider/credential/llm.py index c97476c33..a2db9fe68 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/llm.py +++ b/apps/models_provider/impl/openai_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:32 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -17,7 +16,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class OpenAILLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -58,7 +57,7 @@ 
class OpenAILLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException) or isinstance(e, BadRequestError): raise e if raise_exception: diff --git a/apps/models_provider/impl/openai_model_provider/credential/stt.py b/apps/models_provider/impl/openai_model_provider/credential/stt.py index b70785bc6..1675835c1 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/stt.py +++ b/apps/models_provider/impl/openai_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class OpenAISTTModelParams(BaseForm): language = forms.TextInputField( @@ -38,7 +37,7 @@ class OpenAISTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/openai_model_provider/credential/tti.py b/apps/models_provider/impl/openai_model_provider/credential/tti.py index e999f385c..ba8826742 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/tti.py +++ b/apps/models_provider/impl/openai_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class OpenAITTIModelParams(BaseForm): size = forms.SingleSelect( @@ -70,7 +69,7 @@ class OpenAITextToImageModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/openai_model_provider/credential/tts.py b/apps/models_provider/impl/openai_model_provider/credential/tts.py index b499f3506..6c70aca87 100644 --- a/apps/models_provider/impl/openai_model_provider/credential/tts.py +++ b/apps/models_provider/impl/openai_model_provider/credential/tts.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class OpenAITTSModelGeneralParams(BaseForm): # alloy, echo, fable, onyx, nova, shimmer @@ -49,7 +48,7 @@ 
class OpenAITTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/regolo_model_provider/credential/embedding.py b/apps/models_provider/impl/regolo_model_provider/credential/embedding.py index c935d2c54..3a447036c 100644 --- a/apps/models_provider/impl/regolo_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/regolo_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class RegoloEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -35,7 +34,7 @@ class RegoloEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/regolo_model_provider/credential/image.py b/apps/models_provider/impl/regolo_model_provider/credential/image.py index 6f845fdc7..c45790319 100644 --- a/apps/models_provider/impl/regolo_model_provider/credential/image.py +++ b/apps/models_provider/impl/regolo_model_provider/credential/image.py @@ -1,7 +1,6 @@ # coding=utf-8 import base64 import os -import traceback from typing import Dict from langchain_core.messages import HumanMessage @@ -12,7 +11,7 @@ from common.forms import BaseForm, TooltipLabel from django.utils.translation import gettext_lazy as _, gettext from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class RegoloImageModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -53,10 +52,8 @@ class RegoloImageModelCredential(BaseForm, BaseModelCredential): try: model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.stream([HumanMessage(content=[{"type": "text", "text": gettext('Hello')}])]) - for chunk in res: - print(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/regolo_model_provider/credential/llm.py b/apps/models_provider/impl/regolo_model_provider/credential/llm.py index 46012ea76..d30bca684 100644 --- a/apps/models_provider/impl/regolo_model_provider/credential/llm.py +++ b/apps/models_provider/impl/regolo_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:32 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from 
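One detail worth keeping in mind for the Regolo image credential above: stream() on a LangChain chat model returns a lazy generator, so the provider is only contacted once the iterator is consumed, which is why other image credentials in this diff keep a loop that reads the chunks into maxkb_logger.info. A small illustration, assuming model is any LangChain chat model:

from langchain_core.messages import HumanMessage

from common.utils.logger import maxkb_logger


def exercise_stream(model):
    # stream() itself sends nothing; iterating the generator issues the request
    # and is what surfaces authentication or endpoint errors.
    res = model.stream([HumanMessage(content=[{"type": "text", "text": "Hello"}])])
    for chunk in res:
        maxkb_logger.info(chunk)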
models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class RegoloLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -57,7 +56,7 @@ class RegoloLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/regolo_model_provider/credential/tti.py b/apps/models_provider/impl/regolo_model_provider/credential/tti.py index 8f203cf2d..3fa0eca67 100644 --- a/apps/models_provider/impl/regolo_model_provider/credential/tti.py +++ b/apps/models_provider/impl/regolo_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class RegoloTTIModelParams(BaseForm): size = forms.SingleSelect( @@ -68,9 +67,8 @@ class RegoloTextToImageModelCredential(BaseForm, BaseModelCredential): try: model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() - print(res) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/embedding.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/embedding.py index 0e9dc2e5f..92fa5778d 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -35,7 +34,7 @@ class SiliconCloudEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/image.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/image.py index 3c1664f5e..92b9f83e7 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/image.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/image.py @@ -1,7 +1,6 @@ # coding=utf-8 import base64 import os -import traceback from typing import Dict 
from langchain_core.messages import HumanMessage @@ -56,7 +55,7 @@ class SiliconCloudImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/llm.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/llm.py index 903ffd7a2..e096ba2bc 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/llm.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:32 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class SiliconCloudLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -57,7 +56,7 @@ class SiliconCloudLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py index a7c94a0e2..7a0b17ba6 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/reranker.py @@ -6,7 +6,6 @@ @date:2024/9/9 17:51 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -17,7 +16,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.siliconCloud_model_provider.model.reranker import SiliconCloudReranker - +from common.utils.logger import maxkb_logger class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential): @@ -36,7 +35,7 @@ class SiliconCloudRerankerCredential(BaseForm, BaseModelCredential): model: SiliconCloudReranker = provider.get_model(model_type, model_name, model_credential) model.compress_documents([Document(page_content=_('Hello'))], _('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py index 13e9cbe0e..83efa7d0f 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import 
BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential): api_base = forms.TextInputField('API URL', required=True) @@ -31,7 +30,7 @@ class SiliconCloudSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential,**model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/tti.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/tti.py index c90cf135a..6c252b05c 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/tti.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class SiliconCloudTTIModelParams(BaseForm): size = forms.SingleSelect( @@ -70,7 +69,7 @@ class SiliconCloudTextToImageModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/siliconCloud_model_provider/credential/tts.py b/apps/models_provider/impl/siliconCloud_model_provider/credential/tts.py index e28048323..216b0f160 100644 --- a/apps/models_provider/impl/siliconCloud_model_provider/credential/tts.py +++ b/apps/models_provider/impl/siliconCloud_model_provider/credential/tts.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,6 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +from common.utils.logger import maxkb_logger class SiliconCloudTTSModelGeneralParams(BaseForm): # alloy, echo, fable, onyx, nova, shimmer @@ -50,7 +50,7 @@ class SiliconCloudTTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/tencent_cloud_model_provider/credential/llm.py b/apps/models_provider/impl/tencent_cloud_model_provider/credential/llm.py index 8612956f2..422e63726 100644 --- a/apps/models_provider/impl/tencent_cloud_model_provider/credential/llm.py +++ b/apps/models_provider/impl/tencent_cloud_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:32 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ 
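Across the SiliconCloud credentials above the probe call differs by model type even though the surrounding try/except structure is identical: embeddings call embed_query, rerankers call compress_documents on a single Document, chat models call invoke with a 'Hello' message, and the speech and image credentials call check_auth or consume a short stream. A condensed, purely illustrative dispatcher that summarises those probes (the model_type strings are placeholders; the real code keeps one is_valid per credential class):

from django.utils.translation import gettext
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage


def smoke_test(model, model_type):
    # Hypothetical summary of the per-type probe calls used in this diff.
    if model_type == 'EMBEDDING':
        model.embed_query(gettext('Hello'))
    elif model_type == 'RERANKER':
        model.compress_documents([Document(page_content=gettext('Hello'))], gettext('Hello'))
    elif model_type == 'LLM':
        model.invoke([HumanMessage(content=gettext('Hello'))])
    else:
        # STT/TTS/TTI credentials rely on a provider-specific check_auth().
        model.check_auth()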
from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class TencentCloudLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -57,7 +56,7 @@ class TencentCloudLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py b/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py index dc962e491..9d097dcde 100644 --- a/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py +++ b/apps/models_provider/impl/tencent_cloud_model_provider/model/llm.py @@ -6,9 +6,7 @@ @date:2024/4/18 15:28 @desc: """ -from typing import List, Dict - -from langchain_core.messages import BaseMessage, get_buffer_string +from typing import Dict from common.config.tokenizer_manage_config import TokenizerManage from models_provider.base_model_provider import MaxKBBaseModel diff --git a/apps/models_provider/impl/tencent_model_provider/credential/embedding.py b/apps/models_provider/impl/tencent_model_provider/credential/embedding.py index 1b36f7251..4f31cc110 100644 --- a/apps/models_provider/impl/tencent_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/embedding.py @@ -1,4 +1,3 @@ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -7,7 +6,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class TencentEmbeddingCredential(BaseForm, BaseModelCredential): @@ -22,7 +21,7 @@ class TencentEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/tencent_model_provider/credential/image.py b/apps/models_provider/impl/tencent_model_provider/credential/image.py index 7da20412d..e344ee5d3 100644 --- a/apps/models_provider/impl/tencent_model_provider/credential/image.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/image.py @@ -6,7 +6,6 @@ @date:2024/7/11 18:41 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -58,7 +57,7 @@ class TencentVisionModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/tencent_model_provider/credential/llm.py b/apps/models_provider/impl/tencent_model_provider/credential/llm.py index 0357c097b..b4d254ee0 100644 --- 
a/apps/models_provider/impl/tencent_model_provider/credential/llm.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/llm.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from django.utils.translation import gettext_lazy as _, gettext from langchain_core.messages import HumanMessage @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class TencentLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -50,7 +49,7 @@ class TencentLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext( diff --git a/apps/models_provider/impl/tencent_model_provider/credential/stt.py b/apps/models_provider/impl/tencent_model_provider/credential/stt.py index 3eea500f2..a03684cb4 100644 --- a/apps/models_provider/impl/tencent_model_provider/credential/stt.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/stt.py @@ -1,4 +1,3 @@ -import traceback from common import forms from common.exception.app_exception import AppApiException @@ -6,7 +5,7 @@ from common.forms import BaseForm, TooltipLabel from django.utils.translation import gettext_lazy as _, gettext from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class TencentSSTModelParams(BaseForm): EngSerViceType = forms.SingleSelect( @@ -71,7 +70,7 @@ class TencentSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext( diff --git a/apps/models_provider/impl/tencent_model_provider/credential/tti.py b/apps/models_provider/impl/tencent_model_provider/credential/tti.py index 464c06b17..98df630b4 100644 --- a/apps/models_provider/impl/tencent_model_provider/credential/tti.py +++ b/apps/models_provider/impl/tencent_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from django.utils.translation import gettext_lazy as _, gettext @@ -7,7 +6,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class TencentTTIModelParams(BaseForm): Style = forms.SingleSelect( @@ -97,7 +96,7 @@ class TencentTTIModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if raise_exception: raise AppApiException(ValidCode.valid_error.value, gettext( diff --git a/apps/models_provider/impl/vllm_model_provider/credential/embedding.py b/apps/models_provider/impl/vllm_model_provider/credential/embedding.py index 
9ba9967d6..89a4f19e3 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VllmEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -35,7 +34,7 @@ class VllmEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/vllm_model_provider/credential/image.py b/apps/models_provider/impl/vllm_model_provider/credential/image.py index 663121810..d8a0b235a 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/image.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/image.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -54,7 +53,7 @@ class VllmImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/models_provider/impl/vllm_model_provider/credential/llm.py index 02b1b9a67..e15d858e2 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/llm.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/llm.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -9,7 +8,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -49,7 +48,7 @@ class VLLMModelCredential(BaseForm, BaseModelCredential): try: res = model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) raise AppApiException(ValidCode.valid_error.value, gettext( 'Verification failed, please check whether the parameters are correct: {error}').format( diff --git a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py index a2ae71e73..881c85179 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/reranker.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/reranker.py @@ -1,4 +1,3 @@ -import traceback from typing import Dict from langchain_core.documents import Document @@ -10,7 +9,7 @@ from 
models_provider.base_model_provider import BaseModelCredential, ValidCode from django.utils.translation import gettext_lazy as _ from models_provider.impl.vllm_model_provider.model.reranker import VllmBgeReranker - +from common.utils.logger import maxkb_logger class VllmRerankerCredential(BaseForm, BaseModelCredential): api_url = forms.TextInputField('API URL', required=True) @@ -34,7 +33,7 @@ class VllmRerankerCredential(BaseForm, BaseModelCredential): test_text = str(_('Hello')) model.compress_documents([Document(page_content=test_text)], test_text) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py b/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py index f65b38eac..5844d0a4d 100644 --- a/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py +++ b/apps/models_provider/impl/vllm_model_provider/credential/whisper_stt.py @@ -13,7 +13,7 @@ from models_provider.base_model_provider import BaseModelCredential, ValidCode class VLLMWhisperModelParams(BaseForm): Language = forms.TextInputField( - TooltipLabel(_('Language'), + TooltipLabel(_('language'), _("If not passed, the default value is 'zh'")), required=True, default_value='zh', diff --git a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py index 922d934a8..ca502c4b0 100644 --- a/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py +++ b/apps/models_provider/impl/vllm_model_provider/model/whisper_sst.py @@ -52,11 +52,11 @@ class VllmWhisperSpeechToText(MaxKBBaseModel, BaseSpeechToText): api_key=self.api_key, base_url=base_url ) - + buf = audio_file.read() filter_params = {k: v for k, v in self.params.items() if k not in {'model_id', 'use_local', 'streaming'}} transcription_params = { 'model': self.model, - 'file': audio_file, + 'file': buf, 'language': 'zh', } result = client.audio.transcriptions.create( diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py index e2950940c..5cda8fd88 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/7/12 16:45 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VolcanicEmbeddingCredential(BaseForm, BaseModelCredential): def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, @@ -35,7 +34,7 @@ class VolcanicEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/image.py 
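The vLLM Whisper change above reads the uploaded audio into a bytes buffer and passes that buffer, instead of the Django file object, as the file argument of the OpenAI-compatible transcription call. A minimal sketch of the resulting call, assuming an OpenAI-style endpoint and an already opened audio file; the parameter filtering from the original is omitted, and a (filename, bytes) tuple can be used instead of raw bytes if the server needs an explicit format hint:

from openai import OpenAI


def transcribe(base_url, api_key, model, audio_file, language='zh'):
    client = OpenAI(api_key=api_key, base_url=base_url)
    # Read the upload once; the SDK accepts raw bytes as the file content.
    buf = audio_file.read()
    result = client.audio.transcriptions.create(model=model, file=buf, language=language)
    return result.text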
b/apps/models_provider/impl/volcanic_engine_model_provider/credential/image.py index ee44b35cc..0801ff62e 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/image.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/image.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -54,7 +53,7 @@ class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential): for chunk in res: maxkb_logger.info(chunk) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/llm.py index bc15cff4c..7d330fa98 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/llm.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/11 17:57 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VolcanicEngineLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -56,7 +55,7 @@ class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) res = model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py index 12c18325f..60fed88d0 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,6 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode +from common.utils.logger import maxkb_logger class VolcanicEngineSTTModelParams(BaseForm): uid = forms.TextInputField( @@ -42,7 +42,7 @@ class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/tti.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/tti.py index 480dcff8e..c8269a0a1 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/tti.py +++ 
b/apps/models_provider/impl/volcanic_engine_model_provider/credential/tti.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VolcanicEngineTTIModelGeneralParams(BaseForm): size = forms.SingleSelect( @@ -52,7 +51,7 @@ class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/tts.py index 2bef1211d..2fda2db37 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/tts.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/tts.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VolcanicEngineTTSModelGeneralParams(BaseForm): voice_type = forms.SingleSelect( @@ -60,7 +59,7 @@ class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py b/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py index bdb92a63d..304b415d7 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/credential/ttv.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -9,7 +8,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel, SingleSelect, TextInputField from common.forms.switch_field import SwitchField from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class VolcanicEngineTTVModelGeneralParams(BaseForm): resolution = SingleSelect( @@ -72,7 +71,7 @@ class VolcanicEngineTTVModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/volcanic_engine_model_provider/model/image.py 
b/apps/models_provider/impl/volcanic_engine_model_provider/model/image.py index 05ce9d933..73699d826 100644 --- a/apps/models_provider/impl/volcanic_engine_model_provider/model/image.py +++ b/apps/models_provider/impl/volcanic_engine_model_provider/model/image.py @@ -25,20 +25,5 @@ class VolcanicEngineImage(MaxKBBaseModel, BaseChatOpenAI): def is_cache_model(): return False - def upload_file_and_get_url(self, file_stream, file_name): - """上传文件并获取文件URL""" - base64_video = base64.b64encode(file_stream).decode("utf-8") - video_format = get_video_format(file_name) - return f'data:{video_format};base64,{base64_video}' - -def get_video_format(file_name): - extension = file_name.split('.')[-1].lower() - format_map = { - 'mp4': 'video/mp4', - 'avi': 'video/avi', - 'mov': 'video/mov', - 'wmv': 'video/x-ms-wmv' - } - return format_map.get(extension, 'video/mp4') diff --git a/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py b/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py index a0fc15d13..b250dffe9 100644 --- a/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/wenxin_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/10/17 15:40 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class QianfanEmbeddingCredential(BaseForm, BaseModelCredential): @@ -46,7 +45,7 @@ class QianfanEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/models_provider/impl/wenxin_model_provider/credential/llm.py index 48a636a86..0f30adc6b 100644 --- a/apps/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/12 10:19 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class WenxinLLMModelParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -66,7 +65,7 @@ class WenxinLLMModelCredential(BaseForm, BaseModelCredential): model.invoke( [HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) raise e return True diff --git a/apps/models_provider/impl/xf_model_provider/credential/embedding.py b/apps/models_provider/impl/xf_model_provider/credential/embedding.py index d945da82a..4227f29f8 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/embedding.py +++ b/apps/models_provider/impl/xf_model_provider/credential/embedding.py @@ -6,7 +6,6 @@ @date:2024/10/17 15:40 @desc: """ -import traceback from 
typing import Dict from django.utils.translation import gettext as _ @@ -15,7 +14,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class XFEmbeddingCredential(BaseForm, BaseModelCredential): @@ -30,7 +29,7 @@ class XFEmbeddingCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential) model.embed_query(_('Hello')) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/credential/image.py b/apps/models_provider/impl/xf_model_provider/credential/image.py index e952ea349..7336d1355 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/image.py +++ b/apps/models_provider/impl/xf_model_provider/credential/image.py @@ -1,7 +1,6 @@ # coding=utf-8 import base64 import os -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -12,7 +11,7 @@ from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode from models_provider.impl.xf_model_provider.model.image import ImageMessage - +from common.utils.logger import maxkb_logger class XunFeiImageModelCredential(BaseForm, BaseModelCredential): spark_api_url = forms.TextInputField('API URL', required=True, @@ -42,7 +41,7 @@ class XunFeiImageModelCredential(BaseForm, BaseModelCredential): HumanMessage(_('Please outline this picture'))] model.stream(message_list) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/credential/llm.py b/apps/models_provider/impl/xf_model_provider/credential/llm.py index 5a3de3c1b..5fa6b2bbd 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/llm.py +++ b/apps/models_provider/impl/xf_model_provider/credential/llm.py @@ -6,7 +6,6 @@ @date:2024/7/12 10:29 @desc: """ -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -16,7 +15,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class XunFeiLLMModelGeneralParams(BaseForm): temperature = forms.SliderField(TooltipLabel(_('Temperature'), @@ -75,7 +74,7 @@ class XunFeiLLMModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.invoke([HumanMessage(content=gettext('Hello'))]) except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/credential/stt.py b/apps/models_provider/impl/xf_model_provider/credential/stt.py index 67da706ba..e9cee6a4f 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/stt.py +++ b/apps/models_provider/impl/xf_model_provider/credential/stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import 
traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class XunFeiSTTModelParams(BaseForm): language = forms.TextInputField( @@ -51,7 +50,7 @@ class XunFeiSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/credential/tts.py b/apps/models_provider/impl/xf_model_provider/credential/tts.py index 68a481b29..121c7be91 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/tts.py +++ b/apps/models_provider/impl/xf_model_provider/credential/tts.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext_lazy as _, gettext @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm, TooltipLabel from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class XunFeiTTSModelGeneralParams(BaseForm): vcn = forms.SingleSelect( @@ -56,7 +55,7 @@ class XunFeiTTSModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py b/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py index ebf5ce29c..b4f90e72e 100644 --- a/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py +++ b/apps/models_provider/impl/xf_model_provider/credential/zh_en_stt.py @@ -1,5 +1,4 @@ # coding=utf-8 -import traceback from typing import Dict from django.utils.translation import gettext as _ @@ -8,7 +7,7 @@ from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from models_provider.base_model_provider import BaseModelCredential, ValidCode - +from common.utils.logger import maxkb_logger class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential): @@ -38,7 +37,7 @@ class ZhEnXunFeiSTTModelCredential(BaseForm, BaseModelCredential): model = provider.get_model(model_type, model_name, model_credential, **model_params) model.check_auth() except Exception as e: - traceback.print_exc() + maxkb_logger.error(f'Exception: {e}', exc_info=True) if isinstance(e, AppApiException): raise e if raise_exception: diff --git a/apps/models_provider/impl/xf_model_provider/model/stt.py b/apps/models_provider/impl/xf_model_provider/model/stt.py index b43320746..68f624961 100644 --- a/apps/models_provider/impl/xf_model_provider/model/stt.py +++ b/apps/models_provider/impl/xf_model_provider/model/stt.py @@ -159,7 +159,6 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): "audio": str(base64.b64encode(buf), 'utf-8'), "encoding": "lame"} } - print(d) d = json.dumps(d) await ws.send(d) status = 
diff --git a/apps/models_provider/impl/zhipu_model_provider/credential/image.py b/apps/models_provider/impl/zhipu_model_provider/credential/image.py
index a4baf259a..3fb4a7258 100644
--- a/apps/models_provider/impl/zhipu_model_provider/credential/image.py
+++ b/apps/models_provider/impl/zhipu_model_provider/credential/image.py
@@ -1,5 +1,4 @@
# coding=utf-8
-import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
@@ -53,7 +52,7 @@ class ZhiPuImageModelCredential(BaseForm, BaseModelCredential):
            for chunk in res:
                maxkb_logger.info(chunk)
        except Exception as e:
-            traceback.print_exc()
+            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
diff --git a/apps/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/models_provider/impl/zhipu_model_provider/credential/llm.py
index 9be8a4beb..dea09abdb 100644
--- a/apps/models_provider/impl/zhipu_model_provider/credential/llm.py
+++ b/apps/models_provider/impl/zhipu_model_provider/credential/llm.py
@@ -6,7 +6,6 @@
@date:2024/7/12 10:46
@desc:
"""
-import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
@@ -16,7 +15,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
-
+from common.utils.logger import maxkb_logger

class ZhiPuLLMModelParams(BaseForm):
    temperature = forms.SliderField(TooltipLabel(_('Temperature'),
@@ -55,7 +54,7 @@ class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential):
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            model.invoke([HumanMessage(content=gettext('Hello'))])
        except Exception as e:
-            traceback.print_exc()
+            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
diff --git a/apps/models_provider/impl/zhipu_model_provider/credential/tti.py b/apps/models_provider/impl/zhipu_model_provider/credential/tti.py
index 9cca94dab..2767bf8af 100644
--- a/apps/models_provider/impl/zhipu_model_provider/credential/tti.py
+++ b/apps/models_provider/impl/zhipu_model_provider/credential/tti.py
@@ -1,5 +1,4 @@
# coding=utf-8
-import traceback
from typing import Dict

from django.utils.translation import gettext_lazy as _, gettext
@@ -8,7 +7,7 @@ from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from models_provider.base_model_provider import BaseModelCredential, ValidCode
-
+from common.utils.logger import maxkb_logger

class ZhiPuTTIModelParams(BaseForm):
    size = forms.SingleSelect(
@@ -49,7 +48,7 @@ class ZhiPuTextToImageModelCredential(BaseForm, BaseModelCredential):
            model = provider.get_model(model_type, model_name, model_credential, **model_params)
            res = model.check_auth()
        except Exception as e:
-            traceback.print_exc()
+            maxkb_logger.error(f'Exception: {e}', exc_info=True)
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
diff --git a/apps/models_provider/impl/zhipu_model_provider/model/image.py b/apps/models_provider/impl/zhipu_model_provider/model/image.py
index d470af1ba..cb4bae6b4 100644
--- a/apps/models_provider/impl/zhipu_model_provider/model/image.py
+++ b/apps/models_provider/impl/zhipu_model_provider/model/image.py
@@ -3,7 +3,6 @@ from typing import Dict

from models_provider.base_model_provider import MaxKBBaseModel
from models_provider.impl.base_chat_open_ai import BaseChatOpenAI

-
class ZhiPuImage(MaxKBBaseModel, BaseChatOpenAI):
    @staticmethod
diff --git a/apps/ops/celery/signal_handler.py b/apps/ops/celery/signal_handler.py
index 91cba3ed0..d592803fc 100644
--- a/apps/ops/celery/signal_handler.py
+++ b/apps/ops/celery/signal_handler.py
@@ -17,12 +17,28 @@ from .logger import CeleryThreadTaskFileHandler
logger = logging.getLogger(__file__)
safe_str = lambda x: x

+def init_scheduler():
+    from common import job
+
+    job.run()
+
+    try:
+        from xpack import job as xpack_job
+
+        xpack_job.run()
+    except ImportError:
+        pass
+
+
@worker_ready.connect
def on_app_ready(sender=None, headers=None, **kwargs):
    if cache.get("CELERY_APP_READY", 0) == 1:
        return
    cache.set("CELERY_APP_READY", 1, 10)
+    # Initialize the scheduled jobs
+    init_scheduler()
+
    tasks = get_after_app_ready_tasks()
    logger.debug("Work ready signal recv")
    logger.debug("Start need start task: [{}]".format(", ".join(tasks)))
diff --git a/apps/oss/retrieval_urls.py b/apps/oss/retrieval_urls.py
index 77687a646..816c242ee 100644
--- a/apps/oss/retrieval_urls.py
+++ b/apps/oss/retrieval_urls.py
@@ -17,5 +17,7 @@ urlpatterns = [
             views.FileRetrievalView.as_view()),
    re_path(rf'oss/file/(?P[\w-]+)/?$',
            views.FileRetrievalView.as_view()),
+    re_path(rf'^/oss/get_url/(?P[\w-]+)?$',
+            views.GetUrlView.as_view()),
]
diff --git a/apps/oss/urls.py b/apps/oss/urls.py
index f344049f1..3537b026b 100644
--- a/apps/oss/urls.py
+++ b/apps/oss/urls.py
@@ -6,4 +6,5 @@ app_name = 'oss'
urlpatterns = [
    path('oss/file', views.FileView.as_view()),
+    path('oss/get_url', views.GetUrlView.as_view()),
]
diff --git a/apps/oss/views/file.py b/apps/oss/views/file.py
index a34dbca28..b9f6ca67d 100644
--- a/apps/oss/views/file.py
+++ b/apps/oss/views/file.py
@@ -1,4 +1,7 @@
# coding=utf-8
+import base64
+
+import requests
from django.utils.translation import gettext_lazy as _
from drf_spectacular.utils import extend_schema
from rest_framework.parsers import MultiPartParser
@@ -66,3 +69,33 @@ class FileView(APIView):
    @log(menu='file', operate='Delete file')
    def delete(self, request: Request, file_id: str):
        return result.success(FileSerializer.Operate(data={'id': file_id}).delete())
+
+
+class GetUrlView(APIView):
+    authentication_classes = [TokenAuth]
+
+    @extend_schema(
+        methods=['GET'],
+        summary=_('Get url'),
+        description=_('Get url'),
+        operation_id=_('Get url'),  # type: ignore
+        tags=[_('Chat')]  # type: ignore
+    )
+    def get(self, request: Request):
+        url = request.query_params.get('url')
+        response = requests.get(url)
+        # Return the status code, response size, content type and the byte stream
+        content_type = response.headers.get('Content-Type', '')
+        # Decide how to handle the body based on the content type
+        if 'text' in content_type or 'json' in content_type:
+            content = response.text
+        else:
+            # Binary content is Base64 encoded
+            content = base64.b64encode(response.content).decode('utf-8')
+
+        return result.success({
+            'status_code': response.status_code,
+            'Content-Length': response.headers.get('Content-Length', 0),
+            'Content-Type': content_type,
+            'content': content,
+        })
diff --git a/apps/users/apps.py b/apps/users/apps.py
index 72b140106..c115a830a 100644
--- a/apps/users/apps.py
+++ b/apps/users/apps.py
@@ -4,3 +4,6 @@ from django.apps import AppConfig
class UsersConfig(AppConfig):
    default_auto_field = 'django.db.models.BigAutoField'
    name = 'users'
+
+    def ready(self):
+        from ops.celery import signal_handler
\ No newline at end of file
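The new `GetUrlView` turns the backend into a small fetch proxy: text and JSON bodies come back verbatim, anything else is Base64 encoded. A rough client-side sketch of unpacking that payload, assuming the usual `result.success` envelope with a `data` field; the host, token header, and helper name are placeholders, not part of this diff:

```python
import base64

import requests

BASE_URL = "http://localhost:8080"              # placeholder host
HEADERS = {"Authorization": "Bearer <token>"}   # placeholder auth header


def fetch_via_proxy(url: str):
    resp = requests.get(f"{BASE_URL}/oss/get_url", params={"url": url}, headers=HEADERS)
    payload = resp.json()["data"]
    content_type = payload["Content-Type"]
    if "text" in content_type or "json" in content_type:
        return payload["content"]                # plain string
    # Binary bodies were Base64 encoded by the server; decode back to bytes
    return base64.b64decode(payload["content"])
```

Routing the download through the backend presumably lets the chat UI preview remote files without browser CORS restrictions, at the cost of the server fetching whatever URL it is handed.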
diff --git a/apps/users/serializers/user.py b/apps/users/serializers/user.py
index c850c11bf..7e3775969 100644
--- a/apps/users/serializers/user.py
+++ b/apps/users/serializers/user.py
@@ -139,12 +139,12 @@ class UserManageSerializer(serializers.Serializer):
    username = serializers.CharField(
        required=True,
        label=_("Username"),
-        max_length=20,
+        max_length=64,
        min_length=4,
        validators=[
            validators.RegexValidator(
-                regex=re.compile("^.{4,20}$"),
-                message=_('Username must be 4-20 characters long')
+                regex=re.compile("^.{4,64}$"),
+                message=_('Username must be 4-64 characters long')
            )
        ]
    )
@@ -165,7 +165,7 @@ class UserManageSerializer(serializers.Serializer):
    nick_name = serializers.CharField(
        required=True,
        label=_("Nick name"),
-        max_length=20,
+        max_length=64,
    )
    phone = serializers.CharField(
        required=False,
@@ -203,13 +203,13 @@ class UserManageSerializer(serializers.Serializer):
    username = serializers.CharField(
        required=False,
        label=_("Username"),
-        max_length=20,
+        max_length=64,
        allow_blank=True
    )
    nick_name = serializers.CharField(
        required=False,
        label=_("Nick Name"),
-        max_length=20,
+        max_length=64,
        allow_blank=True
    )
    email = serializers.CharField(
@@ -360,7 +360,7 @@ class UserManageSerializer(serializers.Serializer):
    nick_name = serializers.CharField(
        required=False,
        label=_("Name"),
-        max_length=20,
+        max_length=64,
    )
    phone = serializers.CharField(
        required=False,
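The serializer changes simply relax the username and nickname limits from 20 to 64 characters while keeping the 4-character minimum. Reduced to a plain validation function for illustration (this mirrors the `RegexValidator` above; it is not additional MaxKB code):

```python
import re

USERNAME_PATTERN = re.compile("^.{4,64}$")  # was ^.{4,20}$ before this change


def validate_username(value: str) -> str:
    if not USERNAME_PATTERN.match(value):
        raise ValueError('Username must be 4-64 characters long')
    return value


assert validate_username("maxkb_admin") == "maxkb_admin"
```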
diff --git a/installer/Dockerfile b/installer/Dockerfile
index 6caa63669..34aae093a 100644
--- a/installer/Dockerfile
+++ b/installer/Dockerfile
@@ -13,7 +13,7 @@ RUN apt-get update && \
    apt-get clean all && \
    rm -rf /var/lib/apt/lists/*
WORKDIR /opt/maxkb-app
-RUN gcc -shared -fPIC -o /opt/maxkb-app/sandbox/sandbox.so /opt/maxkb-app/installer/sandbox.c -ldl && \
+RUN gcc -shared -fPIC -o ${MAXKB_SANDBOX_HOME}/sandbox.so /opt/maxkb-app/installer/sandbox.c -ldl && \
    rm -rf /opt/maxkb-app/ui && \
    pip install uv --break-system-packages && \
    python -m uv pip install -r pyproject.toml && \
@@ -47,13 +47,10 @@ ENV MAXKB_VERSION="${DOCKER_IMAGE_TAG} (build at ${BUILD_AT}, commit: ${GITHUB_C
    MAXKB_LOCAL_MODEL_PROTOCOL=http \
    PIP_TARGET=/opt/maxkb/python-packages
-
WORKDIR /opt/maxkb-app
COPY --from=stage-build /opt/maxkb-app /opt/maxkb-app
COPY --from=stage-build /opt/py3 /opt/py3
-RUN chmod 755 /tmp
-
EXPOSE 8080
VOLUME /opt/maxkb
ENTRYPOINT ["bash", "-c"]
diff --git a/installer/Dockerfile-base b/installer/Dockerfile-base
index ebf6c00ce..4e4eccb9f 100644
--- a/installer/Dockerfile-base
+++ b/installer/Dockerfile-base
@@ -1,7 +1,7 @@
FROM python:3.11-slim-trixie AS python-stage
RUN python3 -m venv /opt/py3

-FROM ghcr.io/1panel-dev/maxkb-vector-model:v2.0.2 AS vector-model
+FROM ghcr.io/1panel-dev/maxkb-vector-model:v2.0.3 AS vector-model

FROM postgres:17.6-trixie
COPY --from=python-stage /usr/local /usr/local
@@ -26,12 +26,13 @@ RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \
    curl -L --connect-timeout 120 -m 1800 https://resource.fit2cloud.com/maxkb/ffmpeg/get-ffmpeg-linux | sh && \
    mkdir -p /opt/maxkb-app/sandbox && \
    useradd --no-create-home --home /opt/maxkb-app/sandbox sandbox -g root && \
-    chown -R sandbox:root /opt/maxkb-app/sandbox && \
+    chown -R sandbox:root /opt/maxkb-app/sandbox && chmod 550 /opt/maxkb-app/sandbox && \
    chmod g-xr /usr/local/bin/* /usr/bin/* /bin/* /usr/sbin/* /sbin/* /usr/lib/postgresql/17/bin/* && \
    chmod g+xr /usr/bin/ld.so && \
    chmod g+x /usr/local/bin/python* && \
+    chmod -R g-rwx /tmp /var/tmp /var/lock && \
    apt-get clean all && \
-    rm -rf /var/lib/apt/lists/* /usr/share/doc/* /usr/share/man/* /usr/share/info/* /usr/share/locale/* /usr/share/lintian/* /usr/share/linda/* /var/cache/* /var/log/* /var/tmp/* /tmp/*
+    rm -rf /var/lib/postgresql /var/lib/apt/lists/* /usr/share/doc/* /usr/share/man/* /usr/share/info/* /usr/share/locale/* /usr/share/lintian/* /usr/share/linda/* /var/cache/* /var/log/* /var/tmp/* /tmp/*
COPY --from=vector-model --chmod=700 /opt/maxkb-app/model /opt/maxkb-app/model

ENV PATH=/opt/py3/bin:$PATH \
@@ -45,9 +46,10 @@ ENV PATH=/opt/py3/bin:$PATH \
    MAXKB_CONFIG_TYPE=ENV \
    MAXKB_LOG_LEVEL=INFO \
    MAXKB_SANDBOX=1 \
+    MAXKB_SANDBOX_HOME=/opt/maxkb-app/sandbox \
    MAXKB_SANDBOX_PYTHON_PACKAGE_PATHS="/opt/py3/lib/python3.11/site-packages,/opt/maxkb-app/sandbox/python-packages,/opt/maxkb/python-packages" \
    MAXKB_SANDBOX_PYTHON_BANNED_KEYWORDS="subprocess.,system(,exec(,execve(,pty.,eval(,compile(,shutil.,input(,__import__" \
-    MAXKB_SANDBOX_PYTHON_BANNED_HOSTS="127.0.0.1,localhost,maxkb,pgsql,redis" \
+    MAXKB_SANDBOX_PYTHON_BANNED_HOSTS="127.0.0.1,localhost,host.docker.internal,maxkb,pgsql,redis" \
    MAXKB_ADMIN_PATH=/admin

EXPOSE 6379
\ No newline at end of file
diff --git a/installer/Dockerfile-vector-model b/installer/Dockerfile-vector-model
index c73e03079..6001ace55 100644
--- a/installer/Dockerfile-vector-model
+++ b/installer/Dockerfile-vector-model
@@ -25,7 +25,10 @@ COPY --from=vector-model /opt/maxkb/app/model /opt/maxkb-app/model
COPY --from=vector-model /opt/maxkb/app/model/base/hub /opt/maxkb-app/model/tokenizer
COPY --from=tmp-stage1 model/tokenizer /opt/maxkb-app/model/tokenizer
RUN rm -rf /opt/maxkb-app/model/embedding/shibing624_text2vec-base-chinese/onnx
-
+RUN apk add --update --no-cache curl && \
+    mkdir -p openai-tiktoken-cl100k-base && \
+    curl -Lf https://openaipublic.blob.core.windows.net/encodings/cl100k_base.tiktoken > openai-tiktoken-cl100k-base/cl100k_base.tiktoken && \
+    mv -f openai-tiktoken-cl100k-base /opt/maxkb-app/model/tokenizer/

FROM scratch
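The vector-model image now bundles OpenAI's `cl100k_base.tiktoken` BPE file under the tokenizer directory, presumably so token counting does not need to reach `openaipublic.blob.core.windows.net` at runtime; how MaxKB points tiktoken at the bundled copy is not shown in this diff. For reference, this is the encoding that file backs, using the standard tiktoken API (with a warm cache or bundled copy the lookup works offline, otherwise tiktoken downloads the file):

```python
import tiktoken

# cl100k_base is the BPE table fetched by the Dockerfile step above.
enc = tiktoken.get_encoding("cl100k_base")
print(len(enc.encode("MaxKB token counting example")))
```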
diff --git a/installer/sandbox.c b/installer/sandbox.c
index a06507c9a..9d3a7c928 100644
--- a/installer/sandbox.c
+++ b/installer/sandbox.c
@@ -9,14 +9,51 @@
#include
#include
#include
+#include
+#include

-static const char *ENV_NAME = "SANDBOX_BANNED_HOSTS";
+static const char *BANNED_FILE_NAME = ".SANDBOX_BANNED_HOSTS";
+
+/**
+ * Read the .SANDBOX_BANNED_HOSTS file from the directory containing the .so
+ * Returns a malloc'ed string (caller must free); an empty string on failure
+ */
+static char *load_banned_hosts() {
+    Dl_info info;
+    if (dladdr((void *)load_banned_hosts, &info) == 0 || !info.dli_fname) {
+        fprintf(stderr, "[sandbox] ⚠️ Unable to locate shared object path — allowing all hosts\n");
+        return strdup("");
+    }
+
+    char so_path[PATH_MAX];
+    strncpy(so_path, info.dli_fname, sizeof(so_path));
+    so_path[sizeof(so_path) - 1] = '\0';
+
+    char *dir = dirname(so_path);
+    char file_path[PATH_MAX];
+    snprintf(file_path, sizeof(file_path), "%s/%s", dir, BANNED_FILE_NAME);
+
+    FILE *fp = fopen(file_path, "r");
+    if (!fp) {
+        fprintf(stderr, "[sandbox] ⚠️ Cannot open %s — allowing all hosts\n", file_path);
+        return strdup("");
+    }
+
+    char *buf = malloc(4096);
+    if (!buf) {
+        fclose(fp);
+        fprintf(stderr, "[sandbox] ⚠️ Memory allocation failed — allowing all hosts\n");
+        return strdup("");
+    }
+
+    size_t len = fread(buf, 1, 4095, fp);
+    buf[len] = '\0';
+    fclose(fp);
+    return buf;
+}

/**
 * Exact blacklist matching
- * target: string to check
- * env_val: comma-separated blacklist
- * returns 1 = match, 0 = no match
 */
static int match_env_patterns(const char *target, const char *env_val) {
    if (!target || !env_val || !*env_val) return 0;
@@ -33,7 +70,6 @@ static int match_env_patterns(const char *target, const char *env_val) {
        if (*token) {
            regex_t regex;
-            // Exact match: wrap the pattern in ^ and $, case-insensitive
            char fullpattern[512];
            snprintf(fullpattern, sizeof(fullpattern), "^%s$", token);
@@ -48,7 +84,6 @@ static int match_env_patterns(const char *target, const char *env_val) {
                fprintf(stderr, "[sandbox] ⚠️ Invalid regex '%s' — allowing host by default\n", token);
            }
        }
-
        token = strtok(NULL, ",");
    }
@@ -62,7 +97,8 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
    if (!real_connect)
        real_connect = dlsym(RTLD_NEXT, "connect");

-    const char *banned_env = getenv(ENV_NAME);
+    static char *banned_env = NULL;
+    if (!banned_env) banned_env = load_banned_hosts();

    char ip[INET6_ADDRSTRLEN] = {0};
    if (addr->sa_family == AF_INET)
@@ -70,16 +106,16 @@ int connect(int sockfd, const struct sockaddr *addr, socklen_t addrlen) {
    else if (addr->sa_family == AF_INET6)
        inet_ntop(AF_INET6, &((struct sockaddr_in6 *)addr)->sin6_addr, ip, sizeof(ip));

-    if (banned_env && match_env_patterns(ip, banned_env)) {
+    if (banned_env && *banned_env && match_env_patterns(ip, banned_env)) {
        fprintf(stderr, "[sandbox] 🚫 Access to host %s is banned\n", ip);
-        errno = EACCES;
+        errno = EACCES; // EACCES is 13, i.e. Permission denied
        return -1;
    }
    return real_connect(sockfd, addr, addrlen);
}

-/** Intercept getaddrinfo(): exact domain matching */
+/** Intercept getaddrinfo(): block domain names only, never raw IPs */
int getaddrinfo(const char *node, const char *service,
                const struct addrinfo *hints, struct addrinfo **res) {
    static int (*real_getaddrinfo)(const char *, const char *,
@@ -87,11 +123,21 @@ int getaddrinfo(const char *node, const char *service,
    if (!real_getaddrinfo)
        real_getaddrinfo = dlsym(RTLD_NEXT, "getaddrinfo");

-    const char *banned_env = getenv(ENV_NAME);
+    static char *banned_env = NULL;
+    if (!banned_env) banned_env = load_banned_hosts();

-    if (banned_env && node && match_env_patterns(node, banned_env)) {
-        fprintf(stderr, "[sandbox] 🚫 Access to host %s is banned\n", node);
-        return EAI_FAIL; // Simulate DNS failure
+    if (banned_env && *banned_env && node) {
+        // Check whether node is an IP literal
+        struct in_addr ipv4;
+        struct in6_addr ipv6;
+        int is_ip = (inet_pton(AF_INET, node, &ipv4) == 1) ||
+                    (inet_pton(AF_INET6, node, &ipv6) == 1);
+
+        // Only block non-IP domain names
+        if (!is_ip && match_env_patterns(node, banned_env)) {
+            fprintf(stderr, "[sandbox] 🚫 Access to host %s is banned (DNS blocked)\n", node);
+            return EAI_FAIL; // Simulate a DNS-level refusal
+        }
    }

    return real_getaddrinfo(node, service, hints, res);
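With this change `sandbox.so` no longer consults the `SANDBOX_BANNED_HOSTS` environment variable; it reads a `.SANDBOX_BANNED_HOSTS` file sitting next to the shared object (which also explains the new `.gitignore` entry), `connect()` rejects banned IPs, and `getaddrinfo()` now blocks only non-IP hostnames. A rough sketch of how the Python side might materialize that file and preload the library; the helper and the exact wiring are assumptions for illustration, not MaxKB's `ToolExecutor` implementation:

```python
import os
import subprocess

# Defaults mirroring Dockerfile-base; both values are assumptions here.
SANDBOX_HOME = os.environ.get("MAXKB_SANDBOX_HOME", "/opt/maxkb-app/sandbox")
BANNED_HOSTS = os.environ.get("MAXKB_SANDBOX_PYTHON_BANNED_HOSTS",
                              "127.0.0.1,localhost,host.docker.internal")


def run_sandboxed(script_path: str) -> subprocess.CompletedProcess:
    # sandbox.so loads .SANDBOX_BANNED_HOSTS from its own directory, so the
    # blacklist is written to a file instead of being carried in the child
    # process environment.
    with open(os.path.join(SANDBOX_HOME, ".SANDBOX_BANNED_HOSTS"), "w") as f:
        f.write(BANNED_HOSTS)
    env = {**os.environ, "LD_PRELOAD": os.path.join(SANDBOX_HOME, "sandbox.so")}
    return subprocess.run(["python", script_path], env=env,
                          capture_output=True, text=True)
```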
diff --git a/installer/start-maxkb.sh b/installer/start-maxkb.sh
index c9beef809..73fc757e7 100644
--- a/installer/start-maxkb.sh
+++ b/installer/start-maxkb.sh
@@ -2,13 +2,13 @@
if [ ! -d /opt/maxkb/logs ]; then
    mkdir -p /opt/maxkb/logs
-    chmod 700 /opt/maxkb/logs
fi
+chmod -R 700 /opt/maxkb/logs

if [ ! -d /opt/maxkb/local ]; then
    mkdir -p /opt/maxkb/local
    chmod 700 /opt/maxkb/local
fi
mkdir -p /opt/maxkb/python-packages
-rm -f /opt/maxkb-app/tmp/*.pid
+rm -f /opt/maxkb-app/tmp/*
python /opt/maxkb-app/main.py start
\ No newline at end of file
diff --git a/main.py b/main.py
index 5e4387c69..f738764fb 100644
--- a/main.py
+++ b/main.py
@@ -52,7 +52,7 @@ def start_services():
    if args.worker:
        start_args.extend(['--worker', str(args.worker)])
    else:
-        worker = os.environ.get('CORE_WORKER')
+        worker = os.environ.get('MAXKB_CORE_WORKER')
        if isinstance(worker, str) and worker.isdigit():
            start_args.extend(['--worker', worker])
diff --git a/pyproject.toml b/pyproject.toml
index 76a2cd195..9d4dd900f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
    "langchain-huggingface==0.3.0",
    "langchain-ollama==0.3.4",
    "langgraph==0.5.3",
+    "langchain_core==0.3.74",
    "torch==2.8.0",
    "sentence-transformers==5.0.0",
    "qianfan==0.4.12.3",
diff --git a/ui/package.json b/ui/package.json
index 47c414971..de0fa11ae 100644
--- a/ui/package.json
+++ b/ui/package.json
@@ -28,12 +28,12 @@
    "cropperjs": "^1.6.2",
    "dingtalk-jsapi": "^3.1.0",
    "echarts": "^5.6.0",
-    "element-plus": "^2.10.2",
+    "element-plus": "^2.11.7",
    "file-saver": "^2.0.5",
    "highlight.js": "^11.11.1",
    "html-to-image": "^1.11.13",
    "html2canvas": "^1.4.1",
-    "jspdf": "^3.0.1",
+    "jspdf": "^3.0.3",
    "katex": "^0.16.10",
    "marked": "^12.0.2",
    "md-editor-v3": "^5.8.2",
diff --git a/ui/src/api/application/application.ts b/ui/src/api/application/application.ts
index 07ef53689..5502b2929 100644
--- a/ui/src/api/application/application.ts
+++ b/ui/src/api/application/application.ts
@@ -192,6 +192,26 @@ const getStatistics: (
) => Promise> = (application_id, data, loading) => {
  return get(`${prefix.value}/${application_id}/application_stats`, data, loading)
}
+/**
+ * Token usage statistics
+ */
+const getTokenUsage: (
+  application_id: string,
+  data: any,
+  loading?: Ref,
+) => Promise> = (application_id, data, loading) => {
+  return get(`${prefix.value}/${application_id}/application_token_usage`, data, loading)
+}
+/**
+ * Top question counts
+ */
+const topQuestions: (
+  application_id: string,
+  data: any,
+  loading?: Ref,
+) => Promise> = (application_id, data, loading) => {
+  return get(`${prefix.value}/${application_id}/top_questions`, data, loading)
+}
/**
 * 打开调试对话id
 * @param application_id 应用id
@@ -207,11 +227,11 @@ const open: (application_id: string, loading?: Ref) => Promise
Promise = (
  workspace_id,
@@ -408,5 +428,7 @@ export default {
  speechToText,
  getMcpTools,
  postUploadFile,
-  generate_prompt
+  generate_prompt,
+  getTokenUsage,
+  topQuestions
}
diff --git a/ui/src/api/image.ts b/ui/src/api/image.ts
index e4800216d..6bfacd248 100644
--- a/ui/src/api/image.ts
+++ b/ui/src/api/image.ts
@@ -1,5 +1,5 @@
-import { Result } from '@/request/Result'
-import { get, post, del, put } from '@/request/index'
+import {Result} from '@/request/Result'
+import {get, post, del, put} from '@/request/index'
const prefix = '/oss/file'

/**
@@ -10,6 +10,10 @@ const postImage: (data: any) => Promise> = (data) => {
  return post(`${prefix}`, data)
}

+const getFile: (params: any) => Promise> = (params) => {
+  return get(`/oss/get_url` , params)
+}
export default {
  postImage,
+  getFile
}
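`getTokenUsage` and `topQuestions` are thin wrappers over two new statistics endpoints, `application_token_usage` and `top_questions`, alongside the existing `application_stats`. Called outside the UI they would look roughly like this; the route prefix is workspace-scoped in the frontend code, so the path below is schematic and the filter parameters are assumptions:

```python
import requests

BASE_URL = "http://localhost:8080"               # placeholder host
HEADERS = {"Authorization": "Bearer <token>"}    # placeholder auth header
APP_PREFIX = "/api/application"                  # schematic; the real prefix is workspace-scoped


def get_token_usage(application_id: str, filters: dict) -> dict:
    url = f"{BASE_URL}{APP_PREFIX}/{application_id}/application_token_usage"
    return requests.get(url, params=filters, headers=HEADERS).json()


def top_questions(application_id: str, filters: dict) -> dict:
    url = f"{BASE_URL}{APP_PREFIX}/{application_id}/top_questions"
    return requests.get(url, params=filters, headers=HEADERS).json()
```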
diff --git a/ui/src/api/knowledge/document.ts b/ui/src/api/knowledge/document.ts
index 86587b8b1..9afbee6ac 100644
--- a/ui/src/api/knowledge/document.ts
+++ b/ui/src/api/knowledge/document.ts
@@ -462,7 +462,7 @@ const postTableDocument: (
}

/**
- * 获得QA模版
+ * 获得QA模板
 * @param 参数 fileName,type,
 */
const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = (
@@ -474,7 +474,7 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref)
}

/**
- * 获得table模版
+ * 获得table模板
 * @param 参数 fileName,type,
 */
const exportTableTemplate: (fileName: string, type: string, loading?: Ref) => void = (
diff --git a/ui/src/api/system-resource-management/application.ts b/ui/src/api/system-resource-management/application.ts
index b3ae5a515..6e1b0ef62 100644
--- a/ui/src/api/system-resource-management/application.ts
+++ b/ui/src/api/system-resource-management/application.ts
@@ -111,6 +111,23 @@ const getStatistics: (
) => Promise> = (application_id, data, loading) => {
  return get(`${prefix}/${application_id}/application_stats`, data, loading)
}
+/**
+ * Token usage statistics
+ */
+const getTokenUsage: (
+  application_id: string,
+  data: any,
+  loading?: Ref,
+) => Promise> = (application_id, data, loading) => {
+  return get(`${prefix}/${application_id}/application_token_usage`, data, loading)
+}
+const topQuestions: (
+  application_id: string,
+  data: any,
+  loading?: Ref,
+) => Promise> = (application_id, data, loading) => {
+  return get(`${prefix}/${application_id}/top_questions`, data, loading)
+}
/**
 * 打开调试对话id
 * @param application_id 应用id
@@ -126,10 +143,10 @@ const open: (application_id: string, loading?: Ref) => Promise
Promise = (
  application_id,
@@ -174,7 +191,7 @@ const playDemoText: (application_id: string, data: any, loading?: Ref)
 * 文本转语音
 */
const postTextToSpeech: (
-  application_id: String,
+  application_id: string,
  data: any,
  loading?: Ref,
) => Promise> = (application_id, data, loading) => {
@@ -184,7 +201,7 @@
 * 语音转文本
 */
const speechToText: (
-  application_id: String,
+  application_id: string,
  data: any,
  loading?: Ref,
) => Promise> = (application_id, data, loading) => {
@@ -289,7 +306,7 @@ const updatePlatformConfig: (
/**
 * mcp 节点
 */
-const getMcpTools: (application_id: String, loading?: Ref) => Promise> = (
+const getMcpTools: (application_id: string, loading?: Ref) => Promise> = (
  application_id,
  loading,
) => {
@@ -320,5 +337,7 @@ export default {
  speechToText,
  getMcpTools,
  putXpackAccessToken,
-  generate_prompt
+  generate_prompt,
+  getTokenUsage,
+  topQuestions
}
diff --git a/ui/src/api/system-resource-management/document.ts b/ui/src/api/system-resource-management/document.ts
index c661bc5b8..cc9b235d2 100644
--- a/ui/src/api/system-resource-management/document.ts
+++ b/ui/src/api/system-resource-management/document.ts
@@ -433,7 +433,7 @@ const postTableDocument: (
}

/**
- * 获得QA模版
+ * 获得QA模板
 * @param 参数 fileName,type,
 */
const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = (
@@ -445,7 +445,7 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref)
}

/**
- * 获得table模版
+ * 获得table模板
 * @param 参数 fileName,type,
 */
const exportTableTemplate: (fileName: string, type: string, loading?: Ref) => void = (
diff --git a/ui/src/api/system-shared/document.ts b/ui/src/api/system-shared/document.ts
index dc5c79fdc..10638938b 100644
--- a/ui/src/api/system-shared/document.ts
+++ b/ui/src/api/system-shared/document.ts
@@ -433,7 +433,7 @@ const postTableDocument: (
}

/**
- * 获得QA模版
+ * 获得QA模板
 * @param 参数 fileName,type,
 */
const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = (
@@ -445,7 +445,7 @@ const exportQATemplate: (fileName: string, type: string, loading?: Ref)
}

/**
- * 获得table模版
+ * 获得table模板
 * @param 参数 fileName,type,
 */
const exportTableTemplate: (fileName: string, type: string, loading?: Ref) => void = (
diff --git a/ui/src/api/system/role.ts b/ui/src/api/system/role.ts
index 18e74cb1e..096df3537 100644
--- a/ui/src/api/system/role.ts
+++ b/ui/src/api/system/role.ts
@@ -14,7 +14,7 @@ const getRoleList: (loading?: Ref) => Promise
) => Promise> = (role_type, loading) => {
  return get(`${prefix}/template/${role_type}`, undefined, loading)
@@ -106,4 +106,4 @@ export default {
  getRoleMemberList,
  CreateMember,
  deleteRoleMember
-}
\ No newline at end of file
+}
diff --git a/ui/src/components/ai-chat/component/chat-input-operate/index.vue b/ui/src/components/ai-chat/component/chat-input-operate/index.vue
index d45af15a9..098f39189 100644
--- a/ui/src/components/ai-chat/component/chat-input-operate/index.vue
+++ b/ui/src/components/ai-chat/component/chat-input-operate/index.vue
@@ -4,8 +4,7 @@
        {{ $t('chat.operation.stopChat') }}
-
+
@@ -180,7 +179,8 @@