diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index d8215e983..eecfc6a0f 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -96,7 +96,7 @@ class NoReferencesSetting(serializers.Serializer): def valid_model_params_setting(model_id, model_params_setting): - if model_id is None: + if model_id is None or model_params_setting is None or len(model_params_setting.keys()) == 0: return model = QuerySet(Model).filter(id=model_id).first() credential = get_model_credential(model.provider, model.model_type, model.model_name) @@ -416,7 +416,7 @@ class ApplicationSerializer(serializers.Serializer): model_setting=application.get('model_setting'), problem_optimization=application.get('problem_optimization'), type=ApplicationTypeChoices.SIMPLE, - model_params_setting=application.get('model_params_setting',{}), + model_params_setting=application.get('model_params_setting', {}), work_flow={} ) @@ -697,7 +697,9 @@ class ApplicationSerializer(serializers.Serializer): ApplicationSerializer.Edit(data=instance).is_valid( raise_exception=True) application_id = self.data.get("application_id") - valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting')) + valid_model_params_setting(instance.get('model_id'), + instance.get('model_params_setting')) + application = QuerySet(Application).get(id=application_id) if instance.get('model_id') is None or len(instance.get('model_id')) == 0: application.model_id = None diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 40f90fa0a..ce9d624c2 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -43,3 +43,17 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + try: + super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + try: + super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text))