diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py index 510ef1339..bfaf9f17b 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/llm.py @@ -10,18 +10,20 @@ def get_max_tokens_keyword(model_name): :param model_name: 模型名称字符串 :return: 对应的 max_tokens 关键字字符串 """ - if 'amazon' in model_name: - return 'maxTokenCount' - elif 'anthropic' in model_name: - return 'max_tokens_to_sample' - elif 'ai21' in model_name: + maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"] + # max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"] + maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"] + max_new_tokens = [ + "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0", + "us.meta.llama3-2-90b-instruct-v1:0"] + if model_name in maxTokens: return 'maxTokens' - elif 'cohere' in model_name or 'mistral' in model_name: - return 'max_tokens' - elif 'meta' in model_name: - return 'max_gen_len' + elif model_name in maxTokenCount: + return 'maxTokenCount' + elif model_name in max_new_tokens: + return 'max_new_tokens' else: - raise ValueError("Unsupported model supplier in model_name.") + return 'max_tokens' class BedrockModel(MaxKBBaseModel, BedrockChat): diff --git a/pyproject.toml b/pyproject.toml index 785dae1cc..5cd23bbf3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,6 @@ cgroups = "^0.1.0" wasm-exec = "^0.1.9" boto3 = "^1.34.151" -langchain-aws = "^0.1.13" tencentcloud-sdk-python = "^3.0.1205" xinference-client = "^0.14.0.post1" psutil = "^6.0.0"