From 121614fb81c0a72387974f3ebcb86cbf756a84af Mon Sep 17 00:00:00 2001 From: CaptainB Date: Fri, 5 Sep 2025 17:48:40 +0800 Subject: [PATCH] chore: add model_params_setting to document, knowledge, and paragraph processing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --story=1018694 --user=刘瑞斌 【菲尼克斯】知识库生成问题选择模型,希望可以设置模型参数 https://www.tapd.cn/62980211/s/1768601 --- apps/knowledge/serializers/document.py | 3 +- apps/knowledge/serializers/knowledge.py | 3 +- apps/knowledge/serializers/paragraph.py | 3 +- apps/knowledge/task/generate.py | 16 ++++----- .../generate-related-dialog/index.vue | 34 ++++++++++++++++++- ui/src/stores/modules/prompt.ts | 3 +- 6 files changed, 49 insertions(+), 13 deletions(-) diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py index fa28ead54..a2270a73b 100644 --- a/apps/knowledge/serializers/document.py +++ b/apps/knowledge/serializers/document.py @@ -1305,6 +1305,7 @@ class DocumentSerializers(serializers.Serializer): document_id_list = instance.get("document_id_list") model_id = instance.get("model_id") prompt = instance.get("prompt") + model_params_setting = instance.get("model_params_setting") state_list = instance.get('state_list') ListenerManagement.update_status( QuerySet(Document).filter(id__in=document_id_list), @@ -1327,7 +1328,7 @@ class DocumentSerializers(serializers.Serializer): QuerySet(Document).filter(id__in=document_id_list))() try: for document_id in document_id_list: - generate_related_by_document_id.delay(document_id, model_id, prompt, state_list) + generate_related_by_document_id.delay(document_id, model_id, model_params_setting, prompt, state_list) except AlreadyQueued as e: pass diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py index 8da608253..266b10f4a 100644 --- a/apps/knowledge/serializers/knowledge.py +++ b/apps/knowledge/serializers/knowledge.py @@ -267,6 +267,7 @@ class KnowledgeSerializer(serializers.Serializer): knowledge_id = self.data.get('knowledge_id') model_id = instance.get("model_id") prompt = instance.get("prompt") + model_params_setting = instance.get("model_params_setting") state_list = instance.get('state_list') ListenerManagement.update_status( QuerySet(Document).filter(knowledge_id=knowledge_id), @@ -285,7 +286,7 @@ class KnowledgeSerializer(serializers.Serializer): ) ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)() try: - generate_related_by_knowledge_id.delay(knowledge_id, model_id, prompt, state_list) + generate_related_by_knowledge_id.delay(knowledge_id, model_id, model_params_setting, prompt, state_list) except AlreadyQueued as e: raise AppApiException(500, _('Failed to send the vectorization task, please try again later!')) diff --git a/apps/knowledge/serializers/paragraph.py b/apps/knowledge/serializers/paragraph.py index dcc09f978..28f1aa10b 100644 --- a/apps/knowledge/serializers/paragraph.py +++ b/apps/knowledge/serializers/paragraph.py @@ -480,6 +480,7 @@ class ParagraphSerializers(serializers.Serializer): paragraph_id_list = instance.get("paragraph_id_list") model_id = instance.get("model_id") prompt = instance.get("prompt") + model_params_setting = instance.get("model_params_setting") document_id = self.data.get('document_id') ListenerManagement.update_status( QuerySet(Document).filter(id=document_id), @@ -493,7 +494,7 @@ class ParagraphSerializers(serializers.Serializer): ) ListenerManagement.get_aggregation_document_status(document_id)() try: - generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, prompt) + generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, model_params_setting, prompt) except AlreadyQueued as e: raise AppApiException(500, _('The task is being executed, please do not send it again.')) diff --git a/apps/knowledge/task/generate.py b/apps/knowledge/task/generate.py index e69a2ea85..9a86a7ea2 100644 --- a/apps/knowledge/task/generate.py +++ b/apps/knowledge/task/generate.py @@ -18,9 +18,9 @@ from models_provider.tools import get_model from ops import celery_app -def get_llm_model(model_id): +def get_llm_model(model_id, model_params_setting=None): model = QuerySet(Model).filter(id=model_id).first() - return ModelManage.get_model(model_id, lambda _id: get_model(model)) + return ModelManage.get_model(model_id, lambda _id: get_model(model, **(model_params_setting or {}))) def generate_problem_by_paragraph(paragraph, llm_model, prompt): @@ -64,18 +64,18 @@ def get_is_the_task_interrupted(document_id): @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:generate_related_by_knowledge') -def generate_related_by_knowledge_id(knowledge_id, model_id, prompt, state_list=None): +def generate_related_by_knowledge_id(knowledge_id, model_id, model_params_setting, prompt, state_list=None): document_list = QuerySet(Document).filter(knowledge_id=knowledge_id) for document in document_list: try: - generate_related_by_document_id.delay(document.id, model_id, prompt, state_list) + generate_related_by_document_id.delay(document.id, model_id, model_params_setting, prompt, state_list) except Exception as e: pass @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document') -def generate_related_by_document_id(document_id, model_id, prompt, state_list=None): +def generate_related_by_document_id(document_id, model_id, model_params_setting, prompt, state_list=None): if state_list is None: state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value, State.REVOKE.value, @@ -87,7 +87,7 @@ def generate_related_by_document_id(document_id, model_id, prompt, state_list=No ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.GENERATE_PROBLEM, State.STARTED) - llm_model = get_llm_model(model_id) + llm_model = get_llm_model(model_id, model_params_setting) # 生成问题函数 generate_problem = get_generate_problem(llm_model, prompt, @@ -110,7 +110,7 @@ def generate_related_by_document_id(document_id, model_id, prompt, state_list=No @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:generate_related_by_paragraph_list') -def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt): +def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, model_params_setting, prompt): try: is_the_task_interrupted = get_is_the_task_interrupted(document_id) if is_the_task_interrupted(): @@ -121,7 +121,7 @@ def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_ ListenerManagement.update_status(QuerySet(Document).filter(id=document_id), TaskType.GENERATE_PROBLEM, State.STARTED) - llm_model = get_llm_model(model_id) + llm_model = get_llm_model(model_id, model_params_setting) # 生成问题函数 generate_problem = get_generate_problem(llm_model, prompt, ListenerManagement.get_aggregation_document_status( document_id)) diff --git a/ui/src/components/generate-related-dialog/index.vue b/ui/src/components/generate-related-dialog/index.vue index d41826316..6d944223d 100644 --- a/ui/src/components/generate-related-dialog/index.vue +++ b/ui/src/components/generate-related-dialog/index.vue @@ -28,7 +28,21 @@

{{ $t('views.document.generateQuestion.tip4') }}

- + + + +