mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
fix: aws credentials_profile_name
This commit is contained in:
parent
de97de64b2
commit
85868c13ce
|
|
@ -11,21 +11,6 @@ from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding imp
|
|||
|
||||
class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
@staticmethod
|
||||
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
|
||||
|
||||
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
|
||||
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
|
||||
content = re.sub(pattern, '', content, flags=re.DOTALL)
|
||||
|
||||
if not re.search(rf'\[{profile_name}\]', content):
|
||||
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
|
||||
|
||||
with open(credentials_path, 'w') as file:
|
||||
file.write(content)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
model_type_list = provider.get_model_type_list()
|
||||
|
|
@ -41,9 +26,6 @@ class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
|
|||
return False
|
||||
|
||||
try:
|
||||
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
|
||||
model_credential['secret_access_key'])
|
||||
model_credential['credentials_profile_name'] = 'aws-profile'
|
||||
model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
|
||||
aa = model.embed_query('你好')
|
||||
print(aa)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
|
@ -29,20 +28,7 @@ class BedrockLLMModelParams(BaseForm):
|
|||
|
||||
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
|
||||
|
||||
@staticmethod
|
||||
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
|
||||
|
||||
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
|
||||
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
|
||||
content = re.sub(pattern, '', content, flags=re.DOTALL)
|
||||
|
||||
if not re.search(rf'\[{profile_name}\]', content):
|
||||
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
|
||||
|
||||
with open(credentials_path, 'w') as file:
|
||||
file.write(content)
|
||||
|
||||
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
|
||||
raise_exception=False):
|
||||
|
|
@ -59,9 +45,6 @@ class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
|
|||
return False
|
||||
|
||||
try:
|
||||
self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
|
||||
model_credential['secret_access_key'])
|
||||
model_credential['credentials_profile_name'] = 'aws-profile'
|
||||
model = provider.get_model(model_type, model_name, model_credential, **model_params)
|
||||
model.invoke([HumanMessage(content='你好')])
|
||||
except AppApiException:
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ from langchain_community.embeddings import BedrockEmbeddings
|
|||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
from typing import Dict, List
|
||||
|
||||
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials
|
||||
|
||||
|
||||
class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
|
||||
def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
|
||||
|
|
@ -13,10 +15,12 @@ class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
|
|||
@classmethod
|
||||
def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
|
||||
**model_kwargs) -> 'BedrockModel':
|
||||
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
|
||||
model_credential['secret_access_key'])
|
||||
return cls(
|
||||
model_id=model_name,
|
||||
region_name=model_credential['region_name'],
|
||||
credentials_profile_name=model_credential['credentials_profile_name'],
|
||||
credentials_profile_name=model_credential['access_key_id'],
|
||||
)
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List, Dict
|
||||
|
||||
import os
|
||||
import re
|
||||
from botocore.config import Config
|
||||
from langchain_community.chat_models import BedrockChat
|
||||
from setting.models_provider.base_model_provider import MaxKBBaseModel
|
||||
|
|
@ -57,12 +58,29 @@ class BedrockModel(MaxKBBaseModel, BedrockChat):
|
|||
connect_timeout=60,
|
||||
read_timeout=60
|
||||
)
|
||||
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'],
|
||||
model_credential['secret_access_key'])
|
||||
|
||||
return cls(
|
||||
model_id=model_name,
|
||||
region_name=model_credential['region_name'],
|
||||
credentials_profile_name=model_credential['credentials_profile_name'],
|
||||
credentials_profile_name=model_credential['access_key_id'],
|
||||
streaming=model_kwargs.pop('streaming', True),
|
||||
model_kwargs=optional_params,
|
||||
config=config
|
||||
)
|
||||
|
||||
|
||||
def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
|
||||
credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
|
||||
os.makedirs(os.path.dirname(credentials_path), exist_ok=True)
|
||||
|
||||
content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else ''
|
||||
pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*'
|
||||
content = re.sub(pattern, '', content, flags=re.DOTALL)
|
||||
|
||||
if not re.search(rf'\[{profile_name}\]', content):
|
||||
content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n"
|
||||
|
||||
with open(credentials_path, 'w') as file:
|
||||
file.write(content)
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ class ModelSerializer(serializers.Serializer):
|
|||
source_encryption_model_credential = model_credential.encryption_dict(source_model_credential)
|
||||
if credential is not None:
|
||||
for k in source_encryption_model_credential.keys():
|
||||
if credential[k] == source_encryption_model_credential[k]:
|
||||
if k in credential and credential[k] == source_encryption_model_credential[k]:
|
||||
credential[k] = source_model_credential[k]
|
||||
return credential, model_credential, provider_handler
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue