From 96ac12ea31a18369ac97aa8a06f39948ecdeb4e0 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 Date: Thu, 29 Aug 2024 15:09:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dopenai=E4=BE=9B?= =?UTF-8?q?=E5=BA=94=E5=95=86=E8=AE=A1=E7=AE=97tokens=E9=94=99=E8=AF=AF,?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=97=A7=E7=89=88=E6=9C=AC=E5=BA=94=E7=94=A8?= =?UTF-8?q?=E7=BC=96=E8=BE=91=E9=A1=B5=E9=9D=A2=E7=9B=B4=E6=8E=A5=E6=8A=A5?= =?UTF-8?q?=E9=94=99=E6=8F=90=E7=A4=BA=E9=97=AE=E9=A2=98=E4=B8=BA=E5=BF=85?= =?UTF-8?q?=E5=A1=AB=E5=8F=82=E6=95=B0=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../serializers/application_serializers.py | 8 +++++--- .../impl/openai_model_provider/model/llm.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/apps/application/serializers/application_serializers.py b/apps/application/serializers/application_serializers.py index d8215e983..eecfc6a0f 100644 --- a/apps/application/serializers/application_serializers.py +++ b/apps/application/serializers/application_serializers.py @@ -96,7 +96,7 @@ class NoReferencesSetting(serializers.Serializer): def valid_model_params_setting(model_id, model_params_setting): - if model_id is None: + if model_id is None or model_params_setting is None or len(model_params_setting.keys()) == 0: return model = QuerySet(Model).filter(id=model_id).first() credential = get_model_credential(model.provider, model.model_type, model.model_name) @@ -416,7 +416,7 @@ class ApplicationSerializer(serializers.Serializer): model_setting=application.get('model_setting'), problem_optimization=application.get('problem_optimization'), type=ApplicationTypeChoices.SIMPLE, - model_params_setting=application.get('model_params_setting',{}), + model_params_setting=application.get('model_params_setting', {}), work_flow={} ) @@ -697,7 +697,9 @@ class ApplicationSerializer(serializers.Serializer): ApplicationSerializer.Edit(data=instance).is_valid( raise_exception=True) application_id = self.data.get("application_id") - valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting')) + valid_model_params_setting(instance.get('model_id'), + instance.get('model_params_setting')) + application = QuerySet(Application).get(id=application_id) if instance.get('model_id') is None or len(instance.get('model_id')) == 0: application.model_id = None diff --git a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py index 40f90fa0a..ce9d624c2 100644 --- a/apps/setting/models_provider/impl/openai_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/openai_model_provider/model/llm.py @@ -43,3 +43,17 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI): custom_get_token_ids=custom_get_token_ids ) return azure_chat_open_ai + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + try: + super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + try: + super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text))