refactor: 优化aws大模型的参数

This commit is contained in:
wxg0103 2024-09-27 15:15:31 +08:00
parent c78f216102
commit f869fc5f26
2 changed files with 12 additions and 11 deletions

View File

@ -10,18 +10,20 @@ def get_max_tokens_keyword(model_name):
:param model_name: 模型名称字符串
:return: 对应的 max_tokens 关键字字符串
"""
if 'amazon' in model_name:
return 'maxTokenCount'
elif 'anthropic' in model_name:
return 'max_tokens_to_sample'
elif 'ai21' in model_name:
maxTokens = ["ai21.j2-ultra-v1", "ai21.j2-mid-v1"]
# max_tokens_to_sample = ["anthropic.claude-v2:1", "anthropic.claude-v2", "anthropic.claude-instant-v1"]
maxTokenCount = ["amazon.titan-text-lite-v1", "amazon.titan-text-express-v1"]
max_new_tokens = [
"us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0", "us.meta.llama3-2-11b-instruct-v1:0",
"us.meta.llama3-2-90b-instruct-v1:0"]
if model_name in maxTokens:
return 'maxTokens'
elif 'cohere' in model_name or 'mistral' in model_name:
return 'max_tokens'
elif 'meta' in model_name:
return 'max_gen_len'
elif model_name in maxTokenCount:
return 'maxTokenCount'
elif model_name in max_new_tokens:
return 'max_new_tokens'
else:
raise ValueError("Unsupported model supplier in model_name.")
return 'max_tokens'
class BedrockModel(MaxKBBaseModel, BedrockChat):

View File

@ -52,7 +52,6 @@ cgroups = "^0.1.0"
wasm-exec = "^0.1.9"
boto3 = "^1.34.151"
langchain-aws = "^0.1.13"
tencentcloud-sdk-python = "^3.0.1205"
xinference-client = "^0.14.0.post1"
psutil = "^6.0.0"