From f869fc5f260c3ce80e1d6da5c64b8b6a44a22e0a Mon Sep 17 00:00:00 2001
From: wxg0103 <727495428@qq.com>
Date: Fri, 27 Sep 2024 15:15:31 +0800
Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96aws=E5=A4=A7?=
 =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=9A=84=E5=8F=82=E6=95=B0?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../aws_bedrock_model_provider/model/llm.py   | 22 ++++++++++---------
 pyproject.toml                                |  1 -
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
index 510ef1339..bfaf9f17b 100644
--- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
+++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py
@@ -10,18 +10,20 @@ def get_max_tokens_keyword(model_name):
     :param model_name: 模型名称字符串
     :return: 对应的 max_tokens 关键字字符串
     """
-    if 'amazon' in model_name:
-        return 'maxTokenCount'
-    elif 'anthropic' in model_name:
-        return 'max_tokens_to_sample'
-    elif 'ai21' in model_name:
+    maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
+    # max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"]
+    maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"]
+    max_new_tokens = [
+        "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0",
+        "us.meta.llama3-2-90b-instruct-v1:0"]
+    if model_name in maxTokens:
         return 'maxTokens'
-    elif 'cohere' in model_name or 'mistral' in model_name:
-        return 'max_tokens'
-    elif 'meta' in model_name:
-        return 'max_gen_len'
+    elif model_name in maxTokenCount:
+        return 'maxTokenCount'
+    elif model_name in max_new_tokens:
+        return 'max_new_tokens'
     else:
-        raise ValueError("Unsupported model supplier in model_name.")
+        return 'max_tokens'
 
 
 class BedrockModel(MaxKBBaseModel, BedrockChat):
diff --git a/pyproject.toml b/pyproject.toml
index 785dae1cc..5cd23bbf3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@
 cgroups = "^0.1.0"
 wasm-exec = "^0.1.9"
 boto3 = "^1.34.151"
-langchain-aws = "^0.1.13"
 tencentcloud-sdk-python = "^3.0.1205"
 xinference-client = "^0.14.0.post1"
 psutil = "^6.0.0"