fix: correct token counting in the OpenAI provider, and fix the legacy application edit page incorrectly raising a "required parameter" error

This commit is contained in:
shaohuzhang1 2024-08-29 15:09:01 +08:00 committed by shaohuzhang1
parent b627b6638c
commit 96ac12ea31
2 changed files with 19 additions and 3 deletions

View File

@ -96,7 +96,7 @@ class NoReferencesSetting(serializers.Serializer):
def valid_model_params_setting(model_id, model_params_setting):
if model_id is None:
if model_id is None or model_params_setting is None or len(model_params_setting.keys()) == 0:
return
model = QuerySet(Model).filter(id=model_id).first()
credential = get_model_credential(model.provider, model.model_type, model.model_name)
@ -416,7 +416,7 @@ class ApplicationSerializer(serializers.Serializer):
model_setting=application.get('model_setting'),
problem_optimization=application.get('problem_optimization'),
type=ApplicationTypeChoices.SIMPLE,
model_params_setting=application.get('model_params_setting',{}),
model_params_setting=application.get('model_params_setting', {}),
work_flow={}
)
@ -697,7 +697,9 @@ class ApplicationSerializer(serializers.Serializer):
ApplicationSerializer.Edit(data=instance).is_valid(
raise_exception=True)
application_id = self.data.get("application_id")
valid_model_params_setting(instance.get('model_id'), instance.get('model_params_setting'))
valid_model_params_setting(instance.get('model_id'),
instance.get('model_params_setting'))
application = QuerySet(Application).get(id=application_id)
if instance.get('model_id') is None or len(instance.get('model_id')) == 0:
application.model_id = None

View File

@ -43,3 +43,17 @@ class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
custom_get_token_ids=custom_get_token_ids
)
return azure_chat_open_ai
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
    """Count tokens for a list of chat messages.

    Tries the provider's native counter first; if that fails (e.g. the
    model name is unknown to tiktoken), falls back to the local tokenizer.

    :param messages: chat messages to count tokens for.
    :return: total token count across all messages.
    """
    try:
        # BUG FIX: the original discarded this value and fell through,
        # returning None on the success path.
        return super().get_num_tokens_from_messages(messages)
    except Exception:
        # Fallback: encode each message's buffer string with the
        # project-local tokenizer and sum the lengths.
        tokenizer = TokenizerManage.get_tokenizer()
        return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)
def get_num_tokens(self, text: str) -> int:
    """Count tokens in a plain-text string.

    Tries the provider's native counter first; if that fails, falls back
    to the local tokenizer (mirrors get_num_tokens_from_messages).

    :param text: text to count tokens for.
    :return: token count of ``text``.
    """
    try:
        # BUG FIX: the original discarded this value and fell through,
        # returning None on the success path.
        return super().get_num_tokens(text)
    except Exception:
        tokenizer = TokenizerManage.get_tokenizer()
        return len(tokenizer.encode(text))