diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py index ca2d00e0b..fe6be7d2e 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/i_generate_human_message_step.py @@ -37,6 +37,8 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep): "最大携带知识库段落长度")) # 模板 prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) + system = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("系统提示词(角色)")) # 补齐问题 padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题")) # 未查询到引用分段 @@ -59,6 +61,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep): prompt: str, padding_problem_text: str = None, no_references_setting=None, + system=None, **kwargs) -> List[BaseMessage]: """ @@ -71,6 +74,7 @@ class IGenerateHumanMessageStep(IBaseChatPipelineStep): :param padding_problem_text 用户修改文本 :param kwargs: 其他参数 :param no_references_setting: 无引用分段设置 + :param system 系统提示称 :return: """ pass diff --git a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py index 8b769c770..68cfbbcb9 100644 --- a/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py +++ b/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -9,6 +9,7 @@ from typing import List, Dict from langchain.schema import BaseMessage, HumanMessage +from langchain_core.messages import SystemMessage from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ @@ -27,6 +28,7 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): prompt: str, padding_problem_text: str = None, no_references_setting=None, + system=None, **kwargs) -> List[BaseMessage]: prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get( 'value') @@ -35,6 +37,11 @@ class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in range(start_index if start_index > 0 else 0, len(history_chat_record))] + if system is not None and len(system) > 0: + return [SystemMessage(system), *flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, + no_references_setting)] + return [*flat_map(history_message), self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, no_references_setting)] diff --git a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py index ce30d96af..fe8068119 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -29,6 +29,8 @@ class IResetProblemStep(IBaseChatPipelineStep): error_messages=ErrMessage.list("历史对答")) # 大语言模型 chat_model = ModelField(required=False, allow_null=True, error_messages=ErrMessage.base("大语言模型")) + problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, + error_messages=ErrMessage.char("问题补全提示词")) def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: return self.InstanceSerializer @@ -47,5 +49,6 @@ class IResetProblemStep(IBaseChatPipelineStep): @abstractmethod def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + problem_optimization_prompt=None, **kwargs): pass diff --git a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py index c0595d590..3a32bbf02 100644 --- a/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py +++ b/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -21,6 +21,7 @@ prompt = ( class BaseResetProblemStep(IResetProblemStep): def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, chat_model: BaseChatModel = None, + problem_optimization_prompt=None, **kwargs) -> str: if chat_model is None: self.context['message_tokens'] = 0 @@ -30,8 +31,9 @@ class BaseResetProblemStep(IResetProblemStep): history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] for index in range(start_index if start_index > 0 else 0, len(history_chat_record))] + reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt message_list = [*flat_map(history_message), - HumanMessage(content=prompt.format(**{'question': problem_text}))] + HumanMessage(content=reset_prompt.replace('{question}', problem_text))] response = chat_model.invoke(message_list) padding_problem = problem_text if response.content.__contains__("") and response.content.__contains__(''): @@ -39,6 +41,9 @@ class BaseResetProblemStep(IResetProblemStep): response.content.index('') + 6:response.content.index('')] if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: padding_problem = padding_problem_data + elif len(response.content) > 0: + padding_problem = response.content + try: request_token = chat_model.get_num_tokens_from_messages(message_list) response_token = chat_model.get_num_tokens(padding_problem) diff --git a/apps/application/flow/step_node/start_node/i_start_node.py b/apps/application/flow/step_node/start_node/i_start_node.py index 4c1ecfd2a..bb23ad3f5 100644 --- a/apps/application/flow/step_node/start_node/i_start_node.py +++ b/apps/application/flow/step_node/start_node/i_start_node.py @@ -16,9 +16,6 @@ from application.flow.i_step_node import INode, NodeResult class IStarNode(INode): type = 'start-node' - def get_node_params_serializer_class(self) -> Type[serializers.Serializer] | None: - return None - def _run(self): return self.execute(**self.flow_params_serializer.data) diff --git a/apps/application/flow/step_node/start_node/impl/base_start_node.py b/apps/application/flow/step_node/start_node/impl/base_start_node.py index dc4fb541c..f6528660d 100644 --- a/apps/application/flow/step_node/start_node/impl/base_start_node.py +++ b/apps/application/flow/step_node/start_node/impl/base_start_node.py @@ -15,11 +15,16 @@ from application.flow.step_node.start_node.i_start_node import IStarNode class BaseStartStepNode(IStarNode): def execute(self, question, **kwargs) -> NodeResult: + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + history_context = [{'question': chat_record.problem_text, 'answer': chat_record.answer_text} for chat_record in + history_chat_record] + chat_id = self.flow_params_serializer.data.get('chat_id') """ 开始节点 初始化全局变量 """ return NodeResult({'question': question}, - {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time()}) + {'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'start_time': time.time(), + 'history_context': history_context, 'chat_id': str(chat_id)}) def get_details(self, index: int, **kwargs): global_fields = [] diff --git a/apps/application/migrations/0014_application_problem_optimization_prompt.py b/apps/application/migrations/0014_application_problem_optimization_prompt.py new file mode 100644 index 000000000..e2efc1097 --- /dev/null +++ b/apps/application/migrations/0014_application_problem_optimization_prompt.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-09-13 18:57 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0013_application_tts_type'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='problem_optimization_prompt', + field=models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中', max_length=102400, null=True, verbose_name='问题优化提示词'), + ), + ] diff --git a/apps/application/models/application.py b/apps/application/models/application.py index c03bec1eb..13bf3b6aa 100644 --- a/apps/application/models/application.py +++ b/apps/application/models/application.py @@ -35,7 +35,7 @@ def get_dataset_setting_dict(): def get_model_setting_dict(): - return {'prompt': Application.get_default_model_prompt()} + return {'prompt': Application.get_default_model_prompt(), 'no_references_prompt': '{question}'} class Application(AppModelMixin): @@ -54,8 +54,13 @@ class Application(AppModelMixin): work_flow = models.JSONField(verbose_name="工作流数据", default=dict) type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices, default=ApplicationTypeChoices.SIMPLE, max_length=256) - tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) - stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, blank=True, null=True) + problem_optimization_prompt = models.CharField(verbose_name="问题优化提示词", max_length=102400, blank=True, + null=True, + default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中") + tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, + blank=True, null=True) + stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, + blank=True, null=True) tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False) stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False) tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER") diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index a04af8999..35145eee9 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -120,7 +120,12 @@ class DatasetSettingSerializer(serializers.Serializer): class ModelSettingSerializer(serializers.Serializer): - prompt = serializers.CharField(required=True, max_length=2048, error_messages=ErrMessage.char("提示词")) + prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, + error_messages=ErrMessage.char("提示词")) + system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400, + error_messages=ErrMessage.char("角色提示词")) + no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("无引用分段提示词")) class ApplicationWorkflowSerializer(serializers.Serializer): @@ -174,7 +179,7 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.char("应用描述")) model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("模型")) - multiple_rounds_dialogue = serializers.BooleanField(required=True, error_messages=ErrMessage.char("多轮对话")) + dialogue_number = serializers.BooleanField(required=True, error_messages=ErrMessage.char("会话次数")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, error_messages=ErrMessage.char("开场白")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), @@ -185,6 +190,8 @@ class ApplicationSerializer(serializers.Serializer): model_setting = ModelSettingSerializer(required=True) # 问题补全 problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) + problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, + error_messages=ErrMessage.char("问题补全提示词")) # 应用类型 type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"), validators=[ @@ -364,8 +371,8 @@ class ApplicationSerializer(serializers.Serializer): error_messages=ErrMessage.char("应用描述")) model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, error_messages=ErrMessage.char("模型")) - multiple_rounds_dialogue = serializers.BooleanField(required=False, - error_messages=ErrMessage.boolean("多轮会话")) + dialogue_number = serializers.IntegerField(required=False, + error_messages=ErrMessage.boolean("多轮会话")) prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, error_messages=ErrMessage.char("开场白")) dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), @@ -430,13 +437,14 @@ class ApplicationSerializer(serializers.Serializer): def to_application_model(user_id: str, application: Dict): return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), prologue=application.get('prologue'), - dialogue_number=3 if application.get('multiple_rounds_dialogue') else 0, + dialogue_number=application.get('dialogue_number', 0), user_id=user_id, model_id=application.get('model_id'), dataset_setting=application.get('dataset_setting'), model_setting=application.get('model_setting'), problem_optimization=application.get('problem_optimization'), type=ApplicationTypeChoices.SIMPLE, model_params_setting=application.get('model_params_setting', {}), + problem_optimization_prompt=application.get('problem_optimization_prompt', None), work_flow={} ) @@ -601,7 +609,8 @@ class ApplicationSerializer(serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) application = QuerySet(Application).filter(id=self.data.get("application_id")).first() - return FunctionLibSerializer.Query(data={'user_id': application.user_id}).list(with_valid=True) + return FunctionLibSerializer.Query(data={'user_id': application.user_id, 'is_active': True}).list( + with_valid=True) def get_function_lib(self, function_lib_id, with_valid=True): if with_valid: diff --git a/apps/application/serializers/chat_message_serializers.py b/apps/application/serializers/chat_message_serializers.py index 8fbf0dbbc..44f31759c 100644 --- a/apps/application/serializers/chat_message_serializers.py +++ b/apps/application/serializers/chat_message_serializers.py @@ -60,6 +60,17 @@ class ChatInfo: self.chat_record_list: List[ChatRecord] = [] self.work_flow_version = work_flow_version + @staticmethod + def get_no_references_setting(dataset_setting, model_setting): + no_references_setting = dataset_setting.get( + 'no_references_setting', { + 'status': 'ai_questioning', + 'value': '{question}'}) + if no_references_setting.get('status') == 'ai_questioning': + no_references_prompt = model_setting.get('no_references_prompt', '{question}') + no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}" + return no_references_setting + def to_base_pipeline_manage_params(self): dataset_setting = self.application.dataset_setting model_setting = self.application.model_setting @@ -80,8 +91,13 @@ class ChatInfo: 'history_chat_record': self.chat_record_list, 'chat_id': self.chat_id, 'dialogue_number': self.application.dialogue_number, + 'problem_optimization_prompt': self.application.problem_optimization_prompt if self.application.problem_optimization_prompt is not None and len( + self.application.problem_optimization_prompt) > 0 else '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中', 'prompt': model_setting.get( - 'prompt') if 'prompt' in model_setting else Application.get_default_model_prompt(), + 'prompt') if 'prompt' in model_setting and len(model_setting.get( + 'prompt')) > 0 else Application.get_default_model_prompt(), + 'system': model_setting.get( + 'system', None), 'model_id': model_id, 'problem_optimization': self.application.problem_optimization, 'stream': True, @@ -89,11 +105,7 @@ class ChatInfo: self.application.model_params_setting.keys()) == 0 else self.application.model_params_setting, 'search_mode': self.application.dataset_setting.get( 'search_mode') if 'search_mode' in self.application.dataset_setting else 'embedding', - 'no_references_setting': self.application.dataset_setting.get( - 'no_references_setting') if 'no_references_setting' in self.application.dataset_setting else { - 'status': 'ai_questioning', - 'value': '{question}', - }, + 'no_references_setting': self.get_no_references_setting(self.application.dataset_setting, model_setting), 'user_id': self.application.user_id } diff --git a/apps/application/swagger_api/application_api.py b/apps/application/swagger_api/application_api.py index e153f6279..d05fbb047 100644 --- a/apps/application/swagger_api/application_api.py +++ b/apps/application/swagger_api/application_api.py @@ -40,15 +40,15 @@ class ApplicationApi(ApiMixin): def get_response_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', 'create_time', + required=['id', 'name', 'desc', 'model_id', 'dialogue_number', 'user_id', 'status', 'create_time', 'update_time'], properties={ 'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"), 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), - "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", - description="是否开启多轮对话"), + "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数", + description="多轮对话次数"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), title="示例列表", description="示例列表"), @@ -164,8 +164,8 @@ class ApplicationApi(ApiMixin): 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), - "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", - description="是否开启多轮对话"), + "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数", + description="多轮对话次数"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), @@ -176,7 +176,22 @@ class ApplicationApi(ApiMixin): description="是否开启问题优化", default=True), 'icon': openapi.Schema(type=openapi.TYPE_STRING, title="icon", description="icon", default="/ui/favicon.ico"), - 'work_flow': ApplicationApi.WorkFlow.get_request_body_api() + 'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型", + description="应用类型 简易:SIMPLE|工作流:WORK_FLOW"), + 'work_flow': ApplicationApi.WorkFlow.get_request_body_api(), + 'problem_optimization_prompt': openapi.Schema(type=openapi.TYPE_STRING, title='问题优化提示词', + description="问题优化提示词", + default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中"), + 'tts_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音模型ID", + description="文字转语音模型ID"), + 'stt_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字模型id", + description="语音转文字模型id"), + 'stt_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启", + description="语音转文字是否开启"), + 'tts_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启", + description="语音转文字是否开启"), + 'tts_type': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音类型", + description="文字转语音类型") } ) @@ -248,6 +263,11 @@ class ApplicationApi(ApiMixin): '\n问题:' '\n{question}')), + 'system': openapi.Schema(type=openapi.TYPE_STRING, title="系统提示词(角色)", + description="系统提示词(角色)"), + 'no_references_prompt': openapi.Schema(type=openapi.TYPE_STRING, title="无引用分段提示词", + default="{question}", description="无引用分段提示词") + } ) @@ -267,14 +287,14 @@ class ApplicationApi(ApiMixin): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting', - 'problem_optimization'], + required=['name', 'desc', 'model_id', 'dialogue_number', 'dataset_setting', 'model_setting', + 'problem_optimization', 'stt_model_enable', 'stt_model_enable', 'tts_type'], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), - "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", - description="是否开启多轮对话"), + "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数", + description="多轮对话次数"), 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), 'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), @@ -284,8 +304,20 @@ class ApplicationApi(ApiMixin): 'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化", description="是否开启问题优化", default=True), 'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型", - description="应用类型 简易:SIMPLE|工作流:WORK_FLOW") - + description="应用类型 简易:SIMPLE|工作流:WORK_FLOW"), + 'problem_optimization_prompt': openapi.Schema(type=openapi.TYPE_STRING, title='问题优化提示词', + description="问题优化提示词", + default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中"), + 'tts_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音模型ID", + description="文字转语音模型ID"), + 'stt_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字模型id", + description="语音转文字模型id"), + 'stt_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启", + description="语音转文字是否开启"), + 'tts_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启", + description="语音转文字是否开启"), + 'tts_type': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音类型", + description="文字转语音类型") } ) diff --git a/apps/common/handle/impl/table/xls_parse_table_handle.py b/apps/common/handle/impl/table/xls_parse_table_handle.py index 75e59ede1..6c30d49de 100644 --- a/apps/common/handle/impl/table/xls_parse_table_handle.py +++ b/apps/common/handle/impl/table/xls_parse_table_handle.py @@ -19,26 +19,41 @@ class XlsSplitHandle(BaseParseTableHandle): def handle(self, file, get_buffer, save_image): buffer = get_buffer(file) try: - wb = xlrd.open_workbook(file_contents=buffer) + wb = xlrd.open_workbook(file_contents=buffer, formatting_info=True) result = [] sheets = wb.sheets() for sheet in sheets: + # 获取合并单元格的范围信息 + merged_cells = sheet.merged_cells + print(merged_cells) + data = [] paragraphs = [] - rows = iter([sheet.row_values(i) for i in range(sheet.nrows)]) - if not rows: continue - ti = next(rows) - for r in rows: - l = [] - for i, c in enumerate(r): - if not c: - continue - t = str(ti[i]) if i < len(ti) else "" - t += (": " if t else "") + str(c) - l.append(t) - l = "; ".join(l) - if sheet.name.lower().find("sheet") < 0: - l += " ——" + sheet.name - paragraphs.append({'title': '', 'content': l}) + # 获取第一行作为标题行 + headers = [sheet.cell_value(0, col_idx) for col_idx in range(sheet.ncols)] + # 从第二行开始遍历每一行(跳过标题行) + for row_idx in range(1, sheet.nrows): + row_data = {} + for col_idx in range(sheet.ncols): + cell_value = sheet.cell_value(row_idx, col_idx) + + # 检查是否为空单元格,如果为空检查是否在合并区域中 + if cell_value == "": + # 检查当前单元格是否在合并区域 + for (rlo, rhi, clo, chi) in merged_cells: + if rlo <= row_idx < rhi and clo <= col_idx < chi: + # 使用合并区域的左上角单元格的值 + cell_value = sheet.cell_value(rlo, clo) + break + + # 将标题作为键,单元格的值作为值存入字典 + row_data[headers[col_idx]] = cell_value + data.append(row_data) + + for row in data: + row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) + # print(row_output) + paragraphs.append({'title': '', 'content': row_output}) + result.append({'name': sheet.name, 'paragraphs': paragraphs}) except BaseException as e: diff --git a/apps/common/handle/impl/table/xlsx_parse_table_handle.py b/apps/common/handle/impl/table/xlsx_parse_table_handle.py index c83ef253d..35ef2f14b 100644 --- a/apps/common/handle/impl/table/xlsx_parse_table_handle.py +++ b/apps/common/handle/impl/table/xlsx_parse_table_handle.py @@ -17,6 +17,35 @@ class XlsxSplitHandle(BaseParseTableHandle): return True return False + def fill_merged_cells(self, sheet, image_dict): + data = [] + + # 获取第一行作为标题行 + headers = [cell.value for cell in sheet[1]] + + # 从第二行开始遍历每一行 + for row in sheet.iter_rows(min_row=2, values_only=False): + row_data = {} + for col_idx, cell in enumerate(row): + cell_value = cell.value + + # 如果单元格为空,并且该单元格在合并单元格内,获取合并单元格的值 + if cell_value is None: + for merged_range in sheet.merged_cells.ranges: + if cell.coordinate in merged_range: + cell_value = sheet[merged_range.min_row][merged_range.min_col - 1].value + break + + image = image_dict.get(cell_value, None) + if image is not None: + cell_value = f'' + + # 使用标题作为键,单元格的值作为值存入字典 + row_data[headers[col_idx]] = cell_value + data.append(row_data) + + return data + def handle(self, file, get_buffer, save_image): buffer = get_buffer(file) try: @@ -30,25 +59,13 @@ class XlsxSplitHandle(BaseParseTableHandle): for sheetname in wb.sheetnames: paragraphs = [] ws = wb[sheetname] - rows = list(ws.rows) - if not rows: continue - ti = list(rows[0]) - for r in list(rows[1:]): - l = [] - for i, c in enumerate(r): - if not c.value: - continue - t = str(ti[i].value) if i < len(ti) else "" - content = str(c.value) - image = image_dict.get(content, None) - if image is not None: - content = f'' - t += (": " if t else "") + content - l.append(t) - l = "; ".join(l) - if sheetname.lower().find("sheet") < 0: - l += " ——" + sheetname - paragraphs.append({'title': '', 'content': l}) + data = self.fill_merged_cells(ws, image_dict) + + for row in data: + row_output = "; ".join([f"{key}: {value}" for key, value in row.items()]) + # print(row_output) + paragraphs.append({'title': '', 'content': row_output}) + result.append({'name': sheetname, 'paragraphs': paragraphs}) except BaseException as e: diff --git a/apps/dataset/urls.py b/apps/dataset/urls.py index 2068922ee..405101796 100644 --- a/apps/dataset/urls.py +++ b/apps/dataset/urls.py @@ -23,6 +23,7 @@ urlpatterns = [ path('dataset//document/_bach', views.Document.Batch.as_view()), path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), path('dataset//document//', views.Document.Page.as_view()), + path('dataset//document/batch_refresh', views.Document.BatchRefresh.as_view()), path('dataset//document/', views.Document.Operate.as_view(), name="document_operate"), path('dataset/document/split', views.Document.Split.as_view(), @@ -34,7 +35,6 @@ urlpatterns = [ name="document_export"), path('dataset//document//sync', views.Document.SyncWeb.as_view()), path('dataset//document//refresh', views.Document.Refresh.as_view()), - path('dataset//document/batch_refresh', views.Document.BatchRefresh.as_view()), path('dataset//document//paragraph', views.Paragraph.as_view()), path( 'dataset//document//paragraph/migrate/dataset//document/', diff --git a/apps/dataset/views/document.py b/apps/dataset/views/document.py index c2ef152a0..d41535b0b 100644 --- a/apps/dataset/views/document.py +++ b/apps/dataset/views/document.py @@ -239,7 +239,7 @@ class Document(APIView): class BatchRefresh(APIView): authentication_classes = [TokenAuth] - @action(methods=['POST'], detail=False) + @action(methods=['PUT'], detail=False) @swagger_auto_schema(operation_summary="批量刷新文档向量库", operation_id="批量刷新文档向量库", request_body= diff --git a/apps/function_lib/migrations/0002_functionlib_is_active_functionlib_permission_type.py b/apps/function_lib/migrations/0002_functionlib_is_active_functionlib_permission_type.py new file mode 100644 index 000000000..c665ef22a --- /dev/null +++ b/apps/function_lib/migrations/0002_functionlib_is_active_functionlib_permission_type.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.15 on 2024-09-14 11:23 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('function_lib', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='functionlib', + name='is_active', + field=models.BooleanField(default=True), + ), + migrations.AddField( + model_name='functionlib', + name='permission_type', + field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20, verbose_name='权限类型'), + ), + ] diff --git a/apps/function_lib/models/function.py b/apps/function_lib/models/function.py index d41c6e9bb..49a0e981b 100644 --- a/apps/function_lib/models/function.py +++ b/apps/function_lib/models/function.py @@ -15,6 +15,11 @@ from common.mixins.app_model_mixin import AppModelMixin from users.models import User +class PermissionType(models.TextChoices): + PUBLIC = "PUBLIC", '公开' + PRIVATE = "PRIVATE", "私有" + + class FunctionLib(AppModelMixin): id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户id") @@ -24,6 +29,9 @@ class FunctionLib(AppModelMixin): input_field_list = ArrayField(verbose_name="输入字段列表", base_field=models.JSONField(verbose_name="输入字段", default=dict) , default=list) + is_active = models.BooleanField(default=True) + permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices, + default=PermissionType.PRIVATE) class Meta: db_table = "function_lib" diff --git a/apps/function_lib/serializers/function_lib_serializer.py b/apps/function_lib/serializers/function_lib_serializer.py index 468b07bd2..99866d7bf 100644 --- a/apps/function_lib/serializers/function_lib_serializer.py +++ b/apps/function_lib/serializers/function_lib_serializer.py @@ -11,7 +11,7 @@ import re import uuid from django.core import validators -from django.db.models import QuerySet +from django.db.models import QuerySet, Q from rest_framework import serializers from common.db.search import page_search @@ -27,7 +27,7 @@ function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) class FunctionLibModelSerializer(serializers.ModelSerializer): class Meta: model = FunctionLib - fields = ['id', 'name', 'desc', 'code', 'input_field_list', + fields = ['id', 'name', 'desc', 'code', 'input_field_list', 'permission_type', 'is_active', 'create_time', 'update_time'] @@ -68,6 +68,8 @@ class EditFunctionLib(serializers.Serializer): input_field_list = FunctionLibInputField(required=False, many=True) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char('是否可用')) + class CreateFunctionLib(serializers.Serializer): name = serializers.CharField(required=True, error_messages=ErrMessage.char("函数名称")) @@ -79,6 +81,12 @@ class CreateFunctionLib(serializers.Serializer): input_field_list = FunctionLibInputField(required=True, many=True) + permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char('是否可用')) + class FunctionLibSerializer(serializers.Serializer): class Query(serializers.Serializer): @@ -87,15 +95,19 @@ class FunctionLibSerializer(serializers.Serializer): desc = serializers.CharField(required=False, allow_null=True, allow_blank=True, error_messages=ErrMessage.char("函数描述")) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char("是否可用")) user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) def get_query_set(self): - query_set = QuerySet(FunctionLib).filter(user_id=self.data.get('user_id')) + query_set = QuerySet(FunctionLib).filter( + (Q(user_id=self.data.get('user_id')) | Q(permission_type='PUBLIC'))) if self.data.get('name') is not None: query_set = query_set.filter(name__contains=self.data.get('name')) if self.data.get('desc') is not None: query_set = query_set.filter(desc__contains=self.data.get('desc')) + if self.data.get('is_active') is not None: + query_set = query_set.filter(is_active=self.data.get('is_active')) query_set = query_set.order_by("-create_time") return query_set @@ -120,7 +132,9 @@ class FunctionLibSerializer(serializers.Serializer): function_lib = FunctionLib(id=uuid.uuid1(), name=instance.get('name'), desc=instance.get('desc'), code=instance.get('code'), user_id=self.data.get('user_id'), - input_field_list=instance.get('input_field_list')) + input_field_list=instance.get('input_field_list'), + permission_type=instance.get('permission_type'), + is_active=instance.get('is_active', True)) function_lib.save() return FunctionLibModelSerializer(function_lib).data @@ -193,7 +207,7 @@ class FunctionLibSerializer(serializers.Serializer): if with_valid: self.is_valid(raise_exception=True) EditFunctionLib(data=instance).is_valid(raise_exception=True) - edit_field_list = ['name', 'desc', 'code', 'input_field_list'] + edit_field_list = ['name', 'desc', 'code', 'input_field_list', 'permission_type', 'is_active'] edit_dict = {field: instance.get(field) for field in edit_field_list if ( field in instance and instance.get(field) is not None)} QuerySet(FunctionLib).filter(id=self.data.get('id')).update(**edit_dict) diff --git a/apps/function_lib/swagger_api/function_lib_api.py b/apps/function_lib/swagger_api/function_lib_api.py index ce396c6da..9ab7f7cd3 100644 --- a/apps/function_lib/swagger_api/function_lib_api.py +++ b/apps/function_lib/swagger_api/function_lib_api.py @@ -103,6 +103,8 @@ class FunctionLibApi(ApiMixin): 'name': openapi.Schema(type=openapi.TYPE_STRING, title="函数名称", description="函数名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="函数描述", description="函数描述"), 'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", description="权限"), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"), 'input_field_list': openapi.Schema(type=openapi.TYPE_ARRAY, description="输入变量列表", items=openapi.Schema(type=openapi.TYPE_OBJECT, @@ -135,11 +137,13 @@ class FunctionLibApi(ApiMixin): def get_request_body_api(): return openapi.Schema( type=openapi.TYPE_OBJECT, - required=['name', 'code', 'input_field_list'], + required=['name', 'code', 'input_field_list', 'permission_type'], properties={ 'name': openapi.Schema(type=openapi.TYPE_STRING, title="函数名称", description="函数名称"), 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="函数描述", description="函数描述"), 'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", description="权限"), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"), 'input_field_list': openapi.Schema(type=openapi.TYPE_ARRAY, description="输入变量列表", items=openapi.Schema(type=openapi.TYPE_OBJECT, diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py index 2392d3d04..e64d8b282 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/llm.py @@ -19,7 +19,7 @@ class BedrockLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py index 1906d20c5..b9e730aa0 100644 --- a/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/azure_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class AzureLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py index 72f101c42..ee2279bbc 100644 --- a/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/deepseek_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class DeepSeekLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py index 77e01df39..2612205d4 100644 --- a/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class GeminiLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py index ef2aeb122..1ee2fcee1 100644 --- a/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class KimiLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py index 14634b478..5558bcab4 100644 --- a/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -23,7 +23,7 @@ class OllamaLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py index f7d244a53..58dfc1308 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class OpenAILLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index ff80b0e50..a78ecc0c0 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -40,8 +40,6 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): openai_api_base=model_credential.get('api_base'), openai_api_key=model_credential.get('api_key'), **optional_params, - streaming=True, - stream_usage=True, custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai diff --git a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py index 7ad068454..a8177c545 100644 --- a/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/qwen_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class QwenModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=2048, diff --git a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py index 7ea3cabe4..97c6217c3 100644 --- a/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/vllm_model_provider/credential/llm.py @@ -19,7 +19,7 @@ class VLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py index f918b437d..15fffec2c 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class VolcanicEngineLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 new file mode 100644 index 000000000..75e744c8f Binary files /dev/null and b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 differ diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py index 36aed6426..d416a7586 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py @@ -11,6 +11,7 @@ import base64 import gzip import hmac import json +import os import uuid import wave from enum import Enum @@ -144,6 +145,7 @@ def parse_response(res): result['code'] = code payload_size = int.from_bytes(payload[4:8], "big", signed=False) payload_msg = payload[8:] + print(f"Error code: {code}, message: {payload_msg}") if payload_msg is None: return result if message_compression == GZIP: @@ -321,14 +323,9 @@ class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): return result['payload_msg']['result'][0]['text'] def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) def speech_to_text(self, file): data = file.read() diff --git a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py index 3a5e0afb3..71e022015 100644 --- a/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py +++ b/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py @@ -69,14 +69,7 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): ) def check_auth(self): - header = self.token_auth() - - async def check(): - async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None, - ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + self.text_to_speech('你好') def text_to_speech(self, text): request_json = { @@ -159,7 +152,7 @@ class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): if message_compression == 1: error_msg = gzip.decompress(error_msg) error_msg = str(error_msg, "utf-8") - break + raise Exception(f"Error code: {code}, message: {error_msg}") elif message_type == 0xc: msg_size = int.from_bytes(payload[:4], "big", signed=False) payload = payload[4:] diff --git a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py index be294e81a..a77a6303f 100644 --- a/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/wenxin_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class WenxinLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=2, _max=2048, diff --git a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py index 770aff27d..0a6d9a0ac 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xf_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class XunFeiLLMModelGeneralParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=4096, _min=1, _max=4096, @@ -42,7 +42,7 @@ class XunFeiLLMModelProParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=4096, _min=1, _max=8192, diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 b/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 new file mode 100644 index 000000000..75e744c8f Binary files /dev/null and b/apps/setting/models_provider/impl/xf_model_provider/model/iat_mp3_16k.mp3 differ diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py index f57e6bf13..f400473ed 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/stt.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/stt.py @@ -8,6 +8,8 @@ import datetime import hashlib import hmac import json +import logging +import os from datetime import datetime from typing import Dict from urllib.parse import urlencode, urlparse @@ -25,6 +27,7 @@ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE +max_kb = logging.getLogger("max_kb") class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): spark_app_id: str @@ -89,11 +92,9 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): return url def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) def speech_to_text(self, file): async def handle(): @@ -112,8 +113,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): sid = message["sid"] if code != 0: errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) - return errMsg + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") else: data = message["data"]["result"]["ws"] result = "" diff --git a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py index 8d7755832..004b78858 100644 --- a/apps/setting/models_provider/impl/xf_model_provider/model/tts.py +++ b/apps/setting/models_provider/impl/xf_model_provider/model/tts.py @@ -10,6 +10,7 @@ import datetime import hashlib import hmac import json +import logging import os from datetime import datetime from typing import Dict @@ -20,6 +21,8 @@ import websockets from setting.models_provider.base_model_provider import MaxKBBaseModel from setting.models_provider.impl.base_tts import BaseTextToSpeech +max_kb = logging.getLogger("max_kb") + STATUS_FIRST_FRAME = 0 # 第一帧的标识 STATUS_CONTINUE_FRAME = 1 # 中间帧标识 STATUS_LAST_FRAME = 2 # 最后一帧的标识 @@ -92,11 +95,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): return url def check_auth(self): - async def check(): - async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: - pass - - asyncio.run(check()) + self.text_to_speech("你好") def text_to_speech(self, text): @@ -119,13 +118,13 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): # print(message) code = message["code"] sid = message["sid"] - audio = message["data"]["audio"] - audio = base64.b64decode(audio) if code != 0: errMsg = message["message"] - print("sid:%s call error:%s code is:%s" % (sid, errMsg, code)) + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") else: + audio = message["data"]["audio"] + audio = base64.b64decode(audio) audio_bytes += audio # 退出 if message["data"]["status"] == 2: diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py index bb17e5c22..6317ff663 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -19,7 +19,7 @@ class XinferenceLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=800, _min=1, _max=4096, diff --git a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py index aee7441f1..dc1d1f191 100644 --- a/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py +++ b/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -25,7 +25,7 @@ class ZhiPuLLMModelParams(BaseForm): precision=2) max_tokens = forms.SliderField( - TooltipLabel('输出最大Tokens', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), required=True, default_value=1024, _min=1, _max=4096, diff --git a/apps/users/serializers/user_serializers.py b/apps/users/serializers/user_serializers.py index bc18f97af..4eb83bc22 100644 --- a/apps/users/serializers/user_serializers.py +++ b/apps/users/serializers/user_serializers.py @@ -132,8 +132,8 @@ class RegisterSerializer(ApiMixin, serializers.Serializer): max_length=20, min_length=6, validators=[ - validators.RegexValidator(regex=re.compile("^[a-zA-Z][a-zA-Z0-9_]{5,20}$"), - message="用户名字符数为 6-20 个字符,必须以字母开头,可使用字母、数字、下划线等") + validators.RegexValidator(regex=re.compile("^.{6,20}$"), + message="用户名字符数为 6-20 个字符") ]) password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"), validators=[validators.RegexValidator(regex=re.compile( @@ -590,8 +590,8 @@ class UserManageSerializer(serializers.Serializer): max_length=20, min_length=6, validators=[ - validators.RegexValidator(regex=re.compile("^[a-zA-Z][a-zA-Z0-9_]{5,20}$"), - message="用户名字符数为 6-20 个字符,必须以字母开头,可使用字母、数字、下划线等") + validators.RegexValidator(regex=re.compile("^.{6,20}$"), + message="用户名字符数为 6-20 个字符") ]) password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"), validators=[validators.RegexValidator(regex=re.compile( diff --git a/pyproject.toml b/pyproject.toml index d7fce17fd..785dae1cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ django = "4.2.15" djangorestframework = "^3.15.2" drf-yasg = "1.21.7" django-filter = "23.2" -langchain = "0.2.3" +langchain = "0.2.16" langchain_community = "0.2.4" langchain-huggingface = "^0.0.3" psycopg2-binary = "2.9.7" diff --git a/ui/src/api/document.ts b/ui/src/api/document.ts index 0653f2d40..d4d9b505b 100644 --- a/ui/src/api/document.ts +++ b/ui/src/api/document.ts @@ -3,6 +3,7 @@ import { get, post, del, put, exportExcel } from '@/request/index' import type { Ref } from 'vue' import type { KeyValue } from '@/api/type/common' import type { pageRequest } from '@/api/type/common' + const prefix = '/dataset' /** @@ -26,14 +27,14 @@ const listSplitPattern: ( /** * 文档分页列表 - * @param 参数 dataset_id, + * @param 参数 dataset_id, * page { - "current_page": "string", - "page_size": "string", - } -* param { - "name": "string", - } + "current_page": "string", + "page_size": "string", + } + * param { + "name": "string", + } */ const getDocument: ( @@ -58,22 +59,22 @@ const getAllDocument: (dataset_id: string, loading?: Ref) => Promise Promise> = (dataset_id, data, loading) => { return del(`${prefix}/${dataset_id}/document/_bach`, undefined, { id_list: data }, loading) } + +const batchRefresh: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/batch_refresh`, + { id_list: data }, + undefined, + loading + ) +} /** * 文档详情 * @param 参数 dataset_id @@ -180,14 +194,14 @@ const delMulSyncDocument: ( /** * 创建Web站点文档 - * @param 参数 + * @param 参数 * { - "source_url_list": [ - "string" - ], - "selector": "string" + "source_url_list": [ + "string" + ], + "selector": "string" + } } -} */ const postWebDocument: ( dataset_id: string, @@ -199,9 +213,9 @@ const postWebDocument: ( /** * 导入QA文档 - * @param 参数 + * @param 参数 * file -} + } */ const postQADocument: ( dataset_id: string, @@ -323,5 +337,6 @@ export default { exportTableTemplate, postQADocument, postTableDocument, - exportDocument + exportDocument, + batchRefresh } diff --git a/ui/src/api/type/application.ts b/ui/src/api/type/application.ts index 24215a5fb..97be3b257 100644 --- a/ui/src/api/type/application.ts +++ b/ui/src/api/type/application.ts @@ -4,12 +4,13 @@ interface ApplicationFormType { name?: string desc?: string model_id?: string - multiple_rounds_dialogue?: boolean + dialogue_number?: number prologue?: string dataset_id_list?: string[] dataset_setting?: any model_setting?: any problem_optimization?: boolean + problem_optimization_prompt?: string icon?: string | undefined type?: string work_flow?: any diff --git a/ui/src/api/type/function-lib.ts b/ui/src/api/type/function-lib.ts index ff58d0a7a..2c5efe254 100644 --- a/ui/src/api/type/function-lib.ts +++ b/ui/src/api/type/function-lib.ts @@ -1,9 +1,11 @@ interface functionLibData { id?: String - name: String - desc: String + name?: String + desc?: String code?: String + permission_type?: 'PRIVATE' | 'PUBLIC' input_field_list?: Array + is_active?: Boolean } export type { functionLibData } diff --git a/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/ui/src/components/ai-chat/ExecutionDetailDialog.vue index 715108b1c..a6c4dd48b 100644 --- a/ui/src/components/ai-chat/ExecutionDetailDialog.vue +++ b/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -170,7 +170,7 @@ - 检索内容 + 重排内容 - 检索结果 + 重排结果 diff --git a/ui/src/components/ai-chat/index.vue b/ui/src/components/ai-chat/index.vue index 9e62a4b4c..3784368d4 100644 --- a/ui/src/components/ai-chat/index.vue +++ b/ui/src/components/ai-chat/index.vue @@ -812,7 +812,7 @@ const startRecording = async () => { mediaRecorder.value = new Recorder({ type: 'mp3', bitRate: 128, - sampleRate: 44100 + sampleRate: 16000 }) mediaRecorder.value.open( diff --git a/ui/src/locales/lang/zh_CN/views/application.ts b/ui/src/locales/lang/zh_CN/views/application.ts index 8d0888a15..555494327 100644 --- a/ui/src/locales/lang/zh_CN/views/application.ts +++ b/ui/src/locales/lang/zh_CN/views/application.ts @@ -104,8 +104,10 @@ export default { } }, prompt: { - defaultPrompt: - '已知信息:\n{data}\n回答要求:\n- 请使用简洁且专业的语言来回答用户的问题。\n- 如果你不知道答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。\n- 避免提及你是从已知信息中获得的知识。\n- 请保证答案与已知信息中描述的一致。\n- 请使用 Markdown 语法优化答案的格式。\n- 已知信息中的图片、链接地址和脚本语言请直接返回。\n- 请使用与问题相同的语言来回答。\n问题:\n{question}', + defaultPrompt: `已知信息:{data} +用户问题:{question} +回答要求: + - 请使用中文回答用户问题`, defaultPrologue: '您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?' } diff --git a/ui/src/styles/app.scss b/ui/src/styles/app.scss index 2a6af0b4a..1ba986e5b 100644 --- a/ui/src/styles/app.scss +++ b/ui/src/styles/app.scss @@ -733,3 +733,13 @@ h5 { display: none !important; } } + + +.edit-avatar { + position: relative; + .edit-mask { + position: absolute; + left: 0; + background: rgba(0, 0, 0, 0.4); + } +} \ No newline at end of file diff --git a/ui/src/views/application-overview/index.vue b/ui/src/views/application-overview/index.vue index ae4724efb..b08d3cd1a 100644 --- a/ui/src/views/application-overview/index.vue +++ b/ui/src/views/application-overview/index.vue @@ -332,14 +332,5 @@ onMounted(() => { right: 16px; top: 21px; } - - .edit-avatar { - position: relative; - .edit-mask { - position: absolute; - left: 0; - background: rgba(0, 0, 0, 0.4); - } - } } diff --git a/ui/src/views/application/ApplicationSetting.vue b/ui/src/views/application/ApplicationSetting.vue index 9ef5cef2c..ad950feb9 100644 --- a/ui/src/views/application/ApplicationSetting.vue +++ b/ui/src/views/application/ApplicationSetting.vue @@ -61,10 +61,7 @@ /> - + {{ $t('views.application.applicationForm.form.aiModel.label') }} @@ -151,47 +148,51 @@ + + + {{ $t('views.application.applicationForm.form.prompt.label') }} - * + (无引用知识库) + * + - - {{ - $t('views.application.applicationForm.form.prompt.tooltip', { - data: '{data}', - question: '{question}' - }) - }} - - - - - + + + + + + + + {{ $t('views.application.applicationForm.form.prompt.label') }} + (引用知识库) + * + + + + + + - - - - {{ - $t('views.application.applicationForm.form.problemOptimization.label') - }} - - - - - - - + @@ -305,6 +316,7 @@ - - 浏览器播放(免费) - TTS模型 + + 浏览器播放(免费) + TTS模型 - + + + + - {{ applicationForm?.name || $t('views.application.applicationForm.form.appName.label') @@ -494,6 +521,7 @@ @change="openCreateModel($event)" > + - + diff --git a/ui/src/views/user-manage/index.vue b/ui/src/views/user-manage/index.vue index e8b226357..1bff86190 100644 --- a/ui/src/views/user-manage/index.vue +++ b/ui/src/views/user-manage/index.vue @@ -140,14 +140,22 @@ function createUser() { title.value = '创建用户' UserDialogRef.value.open() } else { - common.asyncGetValid(ValidType.User, ValidCount.User, loading).then((res: any) => { - if (res?.data) { - title.value = '创建用户' - UserDialogRef.value.open() - } else { - MsgAlert('提示', '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。') - } + MsgConfirm(`提示`, '社区版最多支持 2 个用户,如需拥有更多用户,请升级为专业版。', { + cancelButtonText: '确定', + confirmButtonText: '购买专业版', + confirmButtonClass: 'primary' }) + .then(() => { + window.open('https://maxkb.cn/pricing.html', '_blank') + }) + .catch(() => { + common.asyncGetValid(ValidType.User, ValidCount.User, loading).then(async (res: any) => { + if (res?.data) { + title.value = '创建用户' + UserDialogRef.value.open() + } + }) + }) } } diff --git a/ui/src/workflow/common/NodeContainer.vue b/ui/src/workflow/common/NodeContainer.vue index 77f1bdf7f..9c4818a17 100644 --- a/ui/src/workflow/common/NodeContainer.vue +++ b/ui/src/workflow/common/NodeContainer.vue @@ -6,7 +6,7 @@ style="overflow: visible" > - + + + + @@ -46,42 +51,44 @@ - - - - - - 参数输出 - - - {{ item.label }} {{ '{' + item.value + '}' }} - + + + + + 参数输出 + + - - - - - + {{ item.label }} {{ '{' + item.value + '}' }} + + + + + + + - - + + + (false) const anchorData = ref() +const showNode = ref(true) const node_status = computed(() => { if (props.nodeModel.properties.status) { return props.nodeModel.properties.status @@ -240,6 +248,9 @@ onMounted(() => { border: 1px solid #f54a45 !important; } } + .arrow-icon { + transition: 0.2s; + } } :deep(.el-card) { overflow: visible; diff --git a/ui/src/workflow/nodes/base-node/index.vue b/ui/src/workflow/nodes/base-node/index.vue index f90550a15..f0d7ddb17 100644 --- a/ui/src/workflow/nodes/base-node/index.vue +++ b/ui/src/workflow/nodes/base-node/index.vue @@ -118,6 +118,7 @@ - + { diff --git a/ui/src/workflow/nodes/reranker-node/index.vue b/ui/src/workflow/nodes/reranker-node/index.vue index 08d4be5f5..4aaaef270 100644 --- a/ui/src/workflow/nodes/reranker-node/index.vue +++ b/ui/src/workflow/nodes/reranker-node/index.vue @@ -22,6 +22,7 @@ :gutter="8" style="margin-bottom: 8px" v-for="(reranker_reference, index) in form_data.reranker_reference_list" + :key="index" > >([]) const modelOptions = ref(null) const openParamSettingDialog = () => { - ParamSettingDialogRef.value?.open(form_data.value.dataset_setting, 'WORK_FLOW') + ParamSettingDialogRef.value?.open(form_data.value, 'WORK_FLOW') } const deleteCondition = (index: number) => { const list = cloneDeep(props.nodeModel.properties.node_data.reranker_reference_list) @@ -242,7 +243,7 @@ const form_data = computed({ } }) function refreshParam(data: any) { - set(props.nodeModel.properties.node_data, 'reranker_setting', data) + set(props.nodeModel.properties.node_data, 'reranker_setting', data.dataset_setting) } function getModel() { if (id) { diff --git a/ui/src/workflow/nodes/search-dataset-node/index.vue b/ui/src/workflow/nodes/search-dataset-node/index.vue index c065668a4..78d1a4f0f 100644 --- a/ui/src/workflow/nodes/search-dataset-node/index.vue +++ b/ui/src/workflow/nodes/search-dataset-node/index.vue @@ -159,11 +159,11 @@ const datasetList = ref([]) const datasetLoading = ref(false) function refreshParam(data: any) { - set(props.nodeModel.properties.node_data, 'dataset_setting', data) + set(props.nodeModel.properties.node_data, 'dataset_setting', data.dataset_setting) } const openParamSettingDialog = () => { - ParamSettingDialogRef.value?.open(form_data.value.dataset_setting, 'WORK_FLOW') + ParamSettingDialogRef.value?.open(form_data.value, 'WORK_FLOW') } function removeDataset(id: any) { diff --git a/ui/src/workflow/nodes/start-node/index.vue b/ui/src/workflow/nodes/start-node/index.vue index 5dd8835db..e50a109a4 100644 --- a/ui/src/workflow/nodes/start-node/index.vue +++ b/ui/src/workflow/nodes/start-node/index.vue @@ -2,25 +2,18 @@ 全局变量 - 当前时间 {time} + {{ item.label }} {{ '{' + item.value + '}' }} - - - - - - - {{ item.name }} {{ '{' + item.variable + '}' }} - - + @@ -28,34 +21,36 @@