From 33da6073028ee161979de4f13fce9cfdb7b9dbe3 Mon Sep 17 00:00:00 2001 From: wxg0103 <727495428@qq.com> Date: Tue, 2 Dec 2025 14:27:38 +0800 Subject: [PATCH] feat: implement AWS Bedrock Vision-Language and Reranker models with credential validation --- .../impl/response/openai_to_response.py | 6 +- .../aws_bedrock_model_provider.py | 50 ++++++++++ .../credential/image.py | 76 ++++++++++++++++ .../credential/reranker.py | 67 ++++++++++++++ .../aws_bedrock_model_provider/model/image.py | 91 +++++++++++++++++++ .../model/reranker.py | 80 ++++++++++++++++ 6 files changed, 367 insertions(+), 3 deletions(-) create mode 100644 apps/models_provider/impl/aws_bedrock_model_provider/credential/image.py create mode 100644 apps/models_provider/impl/aws_bedrock_model_provider/credential/reranker.py create mode 100644 apps/models_provider/impl/aws_bedrock_model_provider/model/image.py create mode 100644 apps/models_provider/impl/aws_bedrock_model_provider/model/reranker.py diff --git a/apps/common/handle/impl/response/openai_to_response.py b/apps/common/handle/impl/response/openai_to_response.py index 22d4b3bc5..b4eda3625 100644 --- a/apps/common/handle/impl/response/openai_to_response.py +++ b/apps/common/handle/impl/response/openai_to_response.py @@ -20,7 +20,7 @@ from common.handle.base_to_response import BaseToResponse class OpenaiToResponse(BaseToResponse): - def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens, + def to_block_response(self, chat_id, chat_record_id, content, is_end, prompt_tokens, completion_tokens, other_params: dict = None, _status=status.HTTP_200_OK): if other_params is None: @@ -37,8 +37,8 @@ class OpenaiToResponse(BaseToResponse): return JsonResponse(data=data, status=_status) def to_stream_chunk_response(self, chat_id, chat_record_id, node_id, up_node_id_list, content, is_end, - completion_tokens, - prompt_tokens, other_params: dict = None): + prompt_tokens, + completion_tokens, other_params: dict = 
None): if other_params is None: other_params = {} chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk', diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py index 57c52fc9c..3ff8efa83 100644 --- a/apps/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py +++ b/apps/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py @@ -7,12 +7,17 @@ from models_provider.base_model_provider import ( IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage ) from models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential +from models_provider.impl.aws_bedrock_model_provider.credential.image import BedrockVLModelCredential from models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential +from models_provider.impl.aws_bedrock_model_provider.credential.reranker import BedrockRerankerCredential from models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel +from models_provider.impl.aws_bedrock_model_provider.model.image import BedrockVLModel from models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel from maxkb.conf import PROJECT_DIR from django.utils.translation import gettext as _ +from models_provider.impl.aws_bedrock_model_provider.model.reranker import BedrockRerankerModel + def _create_model_info(model_name, description, model_type, credential_class, model_class): return ModelInfo( @@ -127,11 +132,56 @@ def _initialize_model_info(): ), ] + reranker_model_info_list = [ + _create_model_info( + 'amazon.rerank-v1:0', + '', + ModelTypeConst.RERANKER, + BedrockRerankerCredential, + BedrockRerankerModel + ), + _create_model_info( + 'cohere.rerank-v3-5:0', + '', + ModelTypeConst.RERANKER, + BedrockRerankerCredential, + 
BedrockRerankerModel + ) + ] + vl_model_info_list = [ + + _create_model_info( + 'global.anthropic.claude-sonnet-4-5-20250929-v1:0', + '', + ModelTypeConst.IMAGE, + BedrockVLModelCredential, + BedrockVLModel + ), + _create_model_info( + 'us.anthropic.claude-sonnet-4-5-20250929-v1:0', + '', + ModelTypeConst.IMAGE, + BedrockVLModelCredential, + BedrockVLModel + ), + _create_model_info( + 'global.anthropic.claude-haiku-4-5-20251001-v1:0', + '', + ModelTypeConst.IMAGE, + BedrockVLModelCredential, + BedrockVLModel + ), + ] + model_info_manage = ModelInfoManage.builder() \ .append_model_info_list(model_info_list) \ .append_default_model_info(model_info_list[0]) \ .append_model_info_list(embedded_model_info_list) \ .append_default_model_info(embedded_model_info_list[0]) \ + .append_model_info_list(vl_model_info_list) \ + .append_default_model_info(vl_model_info_list[0]) \ + .append_model_info_list(reranker_model_info_list) \ + .append_default_model_info(reranker_model_info_list[0]) \ .build() return model_info_manage diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/credential/image.py b/apps/models_provider/impl/aws_bedrock_model_provider/credential/image.py new file mode 100644 index 000000000..a2bc3092c --- /dev/null +++ b/apps/models_provider/impl/aws_bedrock_model_provider/credential/image.py @@ -0,0 +1,76 @@ +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from models_provider.base_model_provider import ValidCode, BaseModelCredential +from common.utils.logger import maxkb_logger + + +class BedrockImageModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel(_('Temperature'), + _('Higher values make the output more random, while lower values make it more focused and deterministic')), + required=True, 
default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel(_('Output the maximum Tokens'), + _('Specify the maximum number of tokens that the model can generate')), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class BedrockVLModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext('{model_type} Model type is not supported').format(model_type=model_type)) + return False + + required_keys = ['region_name', 'access_key_id', 'secret_access_key'] + if not all(key in model_credential for key in required_keys): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext('The following fields are required: {keys}').format( + keys=", ".join(required_keys))) + return False + + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + model.invoke([HumanMessage(content=gettext('Hello'))]) + except AppApiException: + raise + except Exception as e: + maxkb_logger.error(f'Exception: {e}', exc_info=True) + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + gettext( + 'Verification failed, please check whether the parameters are correct: {error}').format( + error=str(e))) + return False + + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))} + + region_name = forms.TextInputField('Region Name', required=True) + access_key_id = forms.TextInputField('Access Key ID', required=True) + secret_access_key = forms.PasswordInputField('Secret Access Key', 
required=True) + base_url = forms.TextInputField('Proxy URL', required=False) + + def get_model_params_setting_form(self, model_name): + return BedrockImageModelParams() diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/credential/reranker.py b/apps/models_provider/impl/aws_bedrock_model_provider/credential/reranker.py new file mode 100644 index 000000000..46a3ed9e3 --- /dev/null +++ b/apps/models_provider/impl/aws_bedrock_model_provider/credential/reranker.py @@ -0,0 +1,67 @@ +import traceback +from typing import Dict + +from django.utils.translation import gettext_lazy as _, gettext +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from models_provider.base_model_provider import ValidCode, BaseModelCredential + + +class BedrockRerankerModelParams(BaseForm): + top_n = forms.SliderField(TooltipLabel(_('Top N'), + _('Number of top documents to return after reranking')), + required=True, default_value=3, + _min=1, + _max=20, + _step=1, + precision=0) + + +class BedrockRerankerCredential(BaseForm, BaseModelCredential): + access_key_id = forms.PasswordInputField(_('Access Key ID'), required=True) + secret_access_key = forms.PasswordInputField(_('Secret Access Key'), required=True) + region_name = forms.TextInputField(_('Region Name'), required=True, default_value='us-east-1') + base_url = forms.TextInputField(_('Base URL'), required=False) + + def is_valid(self, model_type: str, model_name: str, model_credential: Dict[str, object], model_params, + provider, + raise_exception: bool = False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, _('Model type is not supported')) + + for key in ['access_key_id', 'secret_access_key', 'region_name']: + if key not in model_credential: + if 
raise_exception: + raise AppApiException(ValidCode.valid_error.value, _('%(key)s is required') % {'key': key}) + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential, **model_params) + # Validate against three sample documents; top_n is taken from model_params + test_docs = [ + Document(page_content=str(_('Hello'))), + Document(page_content=str(_('World'))), + Document(page_content=str(_('Test'))) + ] + model.compress_documents(test_docs, str(_('Hello'))) + except Exception as e: + traceback.print_exc() + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, + _('Verification failed, please check whether the parameters are correct: %(error)s') % {'error': str(e)}) + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'access_key_id': super().encryption(model.get('access_key_id', '')), + 'secret_access_key': super().encryption(model.get('secret_access_key', ''))} + + def get_model_params_setting_form(self, model_name): + return BedrockRerankerModelParams() diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/model/image.py b/apps/models_provider/impl/aws_bedrock_model_provider/model/image.py new file mode 100644 index 000000000..9ab812a54 --- /dev/null +++ b/apps/models_provider/impl/aws_bedrock_model_provider/model/image.py @@ -0,0 +1,91 @@ +# coding=utf-8 +""" + @project: MaxKB + @file: image.py + @desc: AWS Bedrock Vision-Language Model Implementation +""" +from typing import Dict, List + +from botocore.config import Config +from langchain_aws import ChatBedrock +from langchain_core.messages import BaseMessage, get_buffer_string + +from common.config.tokenizer_manage_config import TokenizerManage +from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials + + +class BedrockVLModel(MaxKBBaseModel, 
ChatBedrock): + """ + AWS Bedrock Vision-Language Model + Supports Anthropic Claude models with vision capabilities (e.g. Claude Sonnet 4.5, Claude Haiku 4.5) + """ + + @staticmethod + def is_cache_model(): + return False + + def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, + streaming: bool = False, config: Config = None, **kwargs): + super().__init__( + model_id=model_id, + region_name=region_name, + credentials_profile_name=credentials_profile_name, + streaming=streaming, + config=config, + **kwargs + ) + + @classmethod + def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], + **model_kwargs) -> 'BedrockVLModel': + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + + config = {} + # Check if proxy URL is provided + if 'base_url' in model_credential and model_credential['base_url']: + proxy_url = model_credential['base_url'] + config = Config( + proxies={ + 'http': proxy_url, + 'https': proxy_url + }, + connect_timeout=60, + read_timeout=60 + ) + _update_aws_credentials( + model_credential['access_key_id'], + model_credential['access_key_id'], + model_credential['secret_access_key'] + ) + + return cls( + model_id=model_name, + region_name=model_credential['region_name'], + credentials_profile_name=model_credential['access_key_id'], + streaming=model_kwargs.pop('streaming', True), + model_kwargs=optional_params, + config=config + ) + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + """ + Get the number of tokens from messages + Falls back to local tokenizer if the model's tokenizer fails + """ + try: + return super().get_num_tokens_from_messages(messages) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages]) + + def get_num_tokens(self, text: str) -> int: + """ + Get the number of tokens from text + Falls back to local tokenizer if the model's tokenizer fails + """ + try: + 
return super().get_num_tokens(text) + except Exception as e: + tokenizer = TokenizerManage.get_tokenizer() + return len(tokenizer.encode(text)) diff --git a/apps/models_provider/impl/aws_bedrock_model_provider/model/reranker.py b/apps/models_provider/impl/aws_bedrock_model_provider/model/reranker.py new file mode 100644 index 000000000..fb4743f5f --- /dev/null +++ b/apps/models_provider/impl/aws_bedrock_model_provider/model/reranker.py @@ -0,0 +1,80 @@ +import os +import re +from typing import Dict, List, Sequence, Optional, Any + +from botocore.config import Config +from langchain_aws import BedrockRerank +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from pydantic import ConfigDict + +from models_provider.base_model_provider import MaxKBBaseModel +from models_provider.impl.aws_bedrock_model_provider.model.llm import _update_aws_credentials + + +class BedrockRerankerModel(MaxKBBaseModel, BaseDocumentCompressor): + model_config = ConfigDict(arbitrary_types_allowed=True) + + model_id: Optional[str] = None + model_arn: Optional[str] = None + region_name: Optional[str] = None + credentials_profile_name: Optional[str] = None + aws_access_key_id: Optional[str] = None + aws_secret_access_key: Optional[str] = None + config: Optional[Any] = None + top_n: Optional[int] = 3 + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], + **model_kwargs) -> 'BedrockRerankerModel': + top_n = model_kwargs.get('top_n', 3) + region_name = model_credential['region_name'] + model_arn = f"arn:aws:bedrock:{region_name}::foundation-model/{model_name}" + + config = None + if 'base_url' in model_credential and model_credential['base_url']: + proxy_url = model_credential['base_url'] + config = Config( + proxies={ + 'http': proxy_url, + 'https': proxy_url + }, + connect_timeout=60, + read_timeout=60 + ) + + 
_update_aws_credentials(model_credential['access_key_id'], model_credential['access_key_id'], + model_credential['secret_access_key']) + + return BedrockRerankerModel( + model_id=model_name, + model_arn=model_arn, + region_name=region_name, + credentials_profile_name=model_credential['access_key_id'], + aws_access_key_id=model_credential['access_key_id'], + aws_secret_access_key=model_credential['secret_access_key'], + config=config, + top_n=top_n + ) + + def compress_documents(self, documents: Sequence[Document], query: str, + callbacks: Optional[Callbacks] = None) -> Sequence[Document]: + """Compress documents using Bedrock reranking.""" + if not documents: + return [] + + reranker = BedrockRerank( + model_arn=self.model_arn, + region_name=self.region_name, + credentials_profile_name=self.credentials_profile_name, + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + config=self.config, + top_n=self.top_n + ) + return reranker.compress_documents(documents, query, callbacks) +