diff --git a/apps/knowledge/serializers/document.py b/apps/knowledge/serializers/document.py
index fa28ead54..a2270a73b 100644
--- a/apps/knowledge/serializers/document.py
+++ b/apps/knowledge/serializers/document.py
@@ -1305,6 +1305,7 @@ class DocumentSerializers(serializers.Serializer):
         document_id_list = instance.get("document_id_list")
         model_id = instance.get("model_id")
         prompt = instance.get("prompt")
+        model_params_setting = instance.get("model_params_setting")
         state_list = instance.get('state_list')
         ListenerManagement.update_status(
             QuerySet(Document).filter(id__in=document_id_list),
@@ -1327,7 +1328,7 @@ class DocumentSerializers(serializers.Serializer):
                 QuerySet(Document).filter(id__in=document_id_list))()
             try:
                 for document_id in document_id_list:
-                    generate_related_by_document_id.delay(document_id, model_id, prompt, state_list)
+                    generate_related_by_document_id.delay(document_id, model_id, model_params_setting, prompt, state_list)
             except AlreadyQueued as e:
                 pass
diff --git a/apps/knowledge/serializers/knowledge.py b/apps/knowledge/serializers/knowledge.py
index 8da608253..266b10f4a 100644
--- a/apps/knowledge/serializers/knowledge.py
+++ b/apps/knowledge/serializers/knowledge.py
@@ -267,6 +267,7 @@ class KnowledgeSerializer(serializers.Serializer):
         knowledge_id = self.data.get('knowledge_id')
         model_id = instance.get("model_id")
         prompt = instance.get("prompt")
+        model_params_setting = instance.get("model_params_setting")
         state_list = instance.get('state_list')
         ListenerManagement.update_status(
             QuerySet(Document).filter(knowledge_id=knowledge_id),
@@ -285,7 +286,7 @@ class KnowledgeSerializer(serializers.Serializer):
             )
             ListenerManagement.get_aggregation_document_status_by_knowledge_id(knowledge_id)()
             try:
-                generate_related_by_knowledge_id.delay(knowledge_id, model_id, prompt, state_list)
+                generate_related_by_knowledge_id.delay(knowledge_id, model_id, model_params_setting, prompt, state_list)
             except AlreadyQueued as e:
                 raise AppApiException(500, _('Failed to send the vectorization task, please try again later!'))
diff --git a/apps/knowledge/serializers/paragraph.py b/apps/knowledge/serializers/paragraph.py
index dcc09f978..28f1aa10b 100644
--- a/apps/knowledge/serializers/paragraph.py
+++ b/apps/knowledge/serializers/paragraph.py
@@ -480,6 +480,7 @@ class ParagraphSerializers(serializers.Serializer):
         paragraph_id_list = instance.get("paragraph_id_list")
         model_id = instance.get("model_id")
         prompt = instance.get("prompt")
+        model_params_setting = instance.get("model_params_setting")
         document_id = self.data.get('document_id')
         ListenerManagement.update_status(
             QuerySet(Document).filter(id=document_id),
@@ -493,7 +494,7 @@ class ParagraphSerializers(serializers.Serializer):
             )
             ListenerManagement.get_aggregation_document_status(document_id)()
             try:
-                generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, prompt)
+                generate_related_by_paragraph_id_list.delay(document_id, paragraph_id_list, model_id, model_params_setting, prompt)
             except AlreadyQueued as e:
                 raise AppApiException(500, _('The task is being executed, please do not send it again.'))
diff --git a/apps/knowledge/task/generate.py b/apps/knowledge/task/generate.py
index e69a2ea85..9a86a7ea2 100644
--- a/apps/knowledge/task/generate.py
+++ b/apps/knowledge/task/generate.py
@@ -18,9 +18,9 @@ from models_provider.tools import get_model
 from ops import celery_app
 
 
-def get_llm_model(model_id):
+def get_llm_model(model_id, model_params_setting=None):
     model = QuerySet(Model).filter(id=model_id).first()
-    return ModelManage.get_model(model_id, lambda _id: get_model(model))
+    return ModelManage.get_model(model_id, lambda _id: get_model(model, **(model_params_setting or {})))
 
 
 def generate_problem_by_paragraph(paragraph, llm_model, prompt):
@@ -64,18 +64,18 @@ def get_is_the_task_interrupted(document_id):
 
 
 @celery_app.task(base=QueueOnce, once={'keys': ['knowledge_id']}, name='celery:generate_related_by_knowledge')
-def generate_related_by_knowledge_id(knowledge_id, model_id, prompt, state_list=None):
+def generate_related_by_knowledge_id(knowledge_id, model_id, model_params_setting, prompt, state_list=None):
     document_list = QuerySet(Document).filter(knowledge_id=knowledge_id)
     for document in document_list:
         try:
-            generate_related_by_document_id.delay(document.id, model_id, prompt, state_list)
+            generate_related_by_document_id.delay(document.id, model_id, model_params_setting, prompt, state_list)
         except Exception as e:
             pass
 
 
 @celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:generate_related_by_document')
-def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
+def generate_related_by_document_id(document_id, model_id, model_params_setting, prompt, state_list=None):
     if state_list is None:
         state_list = [State.PENDING.value, State.STARTED.value, State.SUCCESS.value, State.FAILURE.value,
                       State.REVOKE.value,
@@ -87,7 +87,7 @@ def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
         ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                          TaskType.GENERATE_PROBLEM,
                                          State.STARTED)
-        llm_model = get_llm_model(model_id)
+        llm_model = get_llm_model(model_id, model_params_setting)
         # 生成问题函数
         generate_problem = get_generate_problem(llm_model, prompt,
@@ -110,7 +110,7 @@ def generate_related_by_document_id(document_id, model_id, prompt, state_list=None):
 
 
 @celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:generate_related_by_paragraph_list')
-def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
+def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, model_params_setting, prompt):
     try:
         is_the_task_interrupted = get_is_the_task_interrupted(document_id)
         if is_the_task_interrupted():
@@ -121,7 +121,7 @@ def generate_related_by_paragraph_id_list(document_id, paragraph_id_list, model_id, prompt):
         ListenerManagement.update_status(QuerySet(Document).filter(id=document_id),
                                          TaskType.GENERATE_PROBLEM,
                                          State.STARTED)
-        llm_model = get_llm_model(model_id)
+        llm_model = get_llm_model(model_id, model_params_setting)
         # 生成问题函数
         generate_problem = get_generate_problem(llm_model, prompt,
                                                 ListenerManagement.get_aggregation_document_status(
                                                     document_id))
diff --git a/ui/src/components/generate-related-dialog/index.vue b/ui/src/components/generate-related-dialog/index.vue
index d41826316..6d944223d 100644
--- a/ui/src/components/generate-related-dialog/index.vue
+++ b/ui/src/components/generate-related-dialog/index.vue
@@ -28,7 +28,21 @@

{{ $t('views.document.generateQuestion.tip4') }}

- + + + +