# apps/application/flow/step_node/image_understand_step_node/impl/base_image_understand_node.py
from application.flow.i_step_node import NodeResult, INode
from application.flow.step_node.image_understand_step_node.i_image_understand_node import IImageUnderstandNode
from dataset.models import File
from setting.models_provider.tools import get_model_instance_by_model_user_id
# NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13;
# plan a replacement before upgrading the interpreter.
from imghdr import what


def file_id_to_base64(file_id: str):
    """Load a stored image and return its base64 text together with its format.

    Args:
        file_id: Identifier used to look up the ``File`` row.

    Returns:
        A two-item list ``[base64_text, image_format]``. ``image_format`` is the
        name reported by ``imghdr.what`` (e.g. ``'png'``) and falls back to
        ``'jpeg'`` when the format cannot be detected, so callers can always
        build a valid ``data:image/<format>;base64,...`` URL.

    Raises:
        ValueError: If no ``File`` row exists for ``file_id``.
    """
    file = QuerySet(File).filter(id=file_id).first()
    if file is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError(f'File not found: {file_id}')
    # presumably get_byte() returns a memoryview (the code calls .tobytes()) — verify
    file_bytes = file.get_byte()
    base64_image = base64.b64encode(file_bytes).decode("utf-8")
    # imghdr.what returns None for unrecognised data; 'jpeg' matches the
    # MIME subtype that was hard-coded before format detection was added.
    image_format = what(None, file_bytes.tobytes()) or 'jpeg'
    return [base64_image, image_format]
# coding=utf-8
# apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/image.py
import base64
import os
from typing import Dict

from langchain_core.messages import HumanMessage

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode


class VolcanicEngineImageModelParams(BaseForm):
    """Tunable generation parameters offered for Volcanic Engine image models."""

    # Sampling temperature slider: 0.1 - 1.0, default 0.95.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.95,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Maximum number of output tokens slider: 1 - 100000, default 1024.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=1024,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)


class VolcanicEngineImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form (API key + API base URL) for Volcanic Engine image models."""

    api_key = forms.PasswordInputField('API Key', required=True)
    api_base = forms.TextInputField('API 域名', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential by sending a short streamed prompt to the model.

        Args:
            model_type: Model type value; must be one the provider lists.
            model_name: Name of the model to instantiate.
            model_credential: Mapping expected to contain ``api_key`` and ``api_base``.
            provider: Provider object exposing ``get_model_type_list`` and ``get_model``.
            raise_exception: When true, failures raise instead of returning ``False``.

        Returns:
            ``True`` when the test request succeeds, ``False`` on failure if
            ``raise_exception`` is false.

        Raises:
            AppApiException: For an unsupported model type (always), and for
                missing fields or a failed test request when ``raise_exception``
                is true.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_key', 'api_base']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # Drain the stream so that connection/auth errors surface here;
            # the chunks themselves are not needed (and must not be printed).
            for _ in model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])]):
                pass
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` replaced by its encrypted form."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form used for every image model name."""
        return VolcanicEngineImageModelParams()


# coding=utf-8
# apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tti.py
# Module header for the text-to-image credential classes that follow.

from typing import Dict

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm, TooltipLabel
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class VolcanicEngineTTIModelGeneralParams(BaseForm):
    """Parameter form for text-to-image models: a single image-size selector."""

    # Each option uses the same "width*height" text for both value and label.
    size = forms.SingleSelect(
        TooltipLabel('图片尺寸',
                     '宽、高与512差距过大,则出图效果不佳、延迟过长概率显著增加。超分前建议比例及对应宽高:width*height'),
        required=True,
        default_value='512*512',
        option_list=[
            {'value': dimensions, 'label': dimensions}
            for dimensions in (
                '512*512',
                '512*384',
                '384*512',
                '512*341',
                '341*512',
                '512*288',
                '288*512',
            )
        ],
        text_field='label',
        value_field='value')


class VolcanicEngineTTIModelCredential(BaseForm, BaseModelCredential):
    """Credential form (access key + secret key) for text-to-image models."""

    access_key = forms.PasswordInputField('Access Key', required=True)
    secret_key = forms.PasswordInputField('Secret Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling ``check_auth``.

        Returns ``True`` on success. On a missing field or a failed check it
        returns ``False``, or raises ``AppApiException`` when ``raise_exception``
        is true. An unsupported ``model_type`` always raises ``AppApiException``.
        """
        supported_types = provider.get_model_type_list()
        type_is_supported = any(
            entry for entry in supported_types if entry.get('value') == model_type
        )
        if not type_is_supported:
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Report the first required field that is absent, in declaration order.
        key = next(
            (name for name in ('access_key', 'secret_key') if name not in model_credential),
            None,
        )
        if key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            provider.get_model(model_type, model_name, model_credential).check_auth()
        except AppApiException:
            # Application-level errors pass through untouched.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``secret_key`` is encrypted."""
        protected = dict(model)
        protected['secret_key'] = super().encryption(model.get('secret_key', ''))
        return protected

    def get_model_params_setting_form(self, model_name):
        """Return the size-selection form regardless of ``model_name``."""
        return VolcanicEngineTTIModelGeneralParams()
# apps/setting/models_provider/impl/volcanic_engine_model_provider/model/image.py
from typing import Dict

from langchain_openai.chat_models import ChatOpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer supplied by ``TokenizerManage``."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class VolcanicEngineImage(MaxKBBaseModel, ChatOpenAI):
    """Image-understanding chat model built on ``ChatOpenAI``."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a streaming model from the credential's ``api_key``/``api_base``.

        Only the keyword arguments kept by
        ``MaxKBBaseModel.filter_optional_params`` are forwarded to the model.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return VolcanicEngineImage(
            model_name=model_name,
            openai_api_key=model_credential.get('api_key'),
            openai_api_base=model_credential.get('api_base'),
            # stream_options={"include_usage": True},
            streaming=True,
            **optional_params,
        )


# coding=utf-8
# apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tti.py
# Text-to-image client: signs each request with HMAC-SHA256 and POSTs it via
# ``requests``.

import datetime
import hashlib
import hmac
import json
import logging
import sys
from typing import Dict

import requests
from langchain_openai import ChatOpenAI

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tti import BaseTextToImage

logger = logging.getLogger(__name__)

method = 'POST'
host = 'visual.volcengineapi.com'
region = 'cn-north-1'
endpoint = 'https://visual.volcengineapi.com'
service = 'cv'

# Maps the model version configured in the application to the API's req_key.
req_key_dict = {
    'general_v1.4': 'high_aes_general_v14',
    'general_v2.0': 'high_aes_general_v20',
    'general_v2.0_L': 'high_aes_general_v20_L',
    'anime_v1.3': 'high_aes',
    'anime_v1.3.1': 'high_aes',
}


def sign(key, msg):
    """Return the raw HMAC-SHA256 digest of ``msg`` (str) under ``key`` (bytes)."""
    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()


def getSignatureKey(key, dateStamp, regionName, serviceName):
    """Derive the signing key by chaining HMACs over date, region, service and 'request'."""
    kDate = sign(key.encode('utf-8'), dateStamp)
    kRegion = sign(kDate, regionName)
    kService = sign(kRegion, serviceName)
    kSigning = sign(kService, 'request')
    return kSigning


def formatQuery(parameters):
    """Render ``parameters`` as ``k=v`` pairs joined by ``&``, sorted by key.

    Values are inserted as-is (no URL encoding), matching what gets signed.
    """
    return '&'.join(f'{key}={parameters[key]}' for key in sorted(parameters))


def signV4Request(access_key, secret_key, service, req_query, req_body):
    """Sign and send a POST request, returning the generated image URLs.

    Args:
        access_key: Access key placed in the ``Credential`` part of the header.
        secret_key: Secret used to derive the signing key.
        service: Service name used in the credential scope.
        req_query: Already formatted query string (see ``formatQuery``).
        req_body: JSON request body as a string.

    Returns:
        The value of ``data.image_urls`` from the JSON response.

    Raises:
        ValueError: If either key is ``None``.
        Exception: If the response status code is not 200; the message carries
            the response text.
    """
    if access_key is None or secret_key is None:
        # Raise instead of sys.exit(): exiting here would terminate the whole
        # host process and cannot be handled by ``except Exception`` callers.
        raise ValueError('No access key is available.')

    t = datetime.datetime.now(datetime.timezone.utc)
    current_date = t.strftime('%Y%m%dT%H%M%SZ')
    datestamp = t.strftime('%Y%m%d')  # Date w/o time, used in credential scope
    canonical_uri = '/'
    canonical_querystring = req_query
    signed_headers = 'content-type;host;x-content-sha256;x-date'
    payload_hash = hashlib.sha256(req_body.encode('utf-8')).hexdigest()
    content_type = 'application/json'
    canonical_headers = (
        f'content-type:{content_type}\n'
        f'host:{host}\n'
        f'x-content-sha256:{payload_hash}\n'
        f'x-date:{current_date}\n'
    )
    canonical_request = '\n'.join([
        method,
        canonical_uri,
        canonical_querystring,
        canonical_headers,
        signed_headers,
        payload_hash,
    ])
    algorithm = 'HMAC-SHA256'
    credential_scope = f'{datestamp}/{region}/{service}/request'
    string_to_sign = '\n'.join([
        algorithm,
        current_date,
        credential_scope,
        hashlib.sha256(canonical_request.encode('utf-8')).hexdigest(),
    ])
    signing_key = getSignatureKey(secret_key, datestamp, region, service)
    signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()

    authorization_header = (
        f'{algorithm} Credential={access_key}/{credential_scope}, '
        f'SignedHeaders={signed_headers}, Signature={signature}'
    )
    headers = {
        'X-Date': current_date,
        'Authorization': authorization_header,
        'X-Content-Sha256': payload_hash,
        'Content-Type': content_type,
    }

    request_url = endpoint + '?' + canonical_querystring
    logger.debug('Request URL = %s', request_url)
    # NOTE(review): no timeout is set, so a stalled connection blocks the caller.
    r = requests.post(request_url, headers=headers, data=req_body)
    logger.debug('Response code: %s', r.status_code)
    # The API escapes '&' as \u0026 in the body; restore it before parsing.
    resp_str = r.text.replace("\\u0026", "&")
    if r.status_code != 200:
        raise Exception(f'Error: {resp_str}')
    logger.debug('Response body: %s', resp_str)
    return json.loads(resp_str)['data']['image_urls']


class VolcanicEngineTextToImage(MaxKBBaseModel, BaseTextToImage):
    """Text-to-image model that calls the signed ``CVProcess`` action."""

    access_key: str
    secret_key: str
    model_version: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.access_key = kwargs.get('access_key')
        self.secret_key = kwargs.get('secret_key')
        self.model_version = kwargs.get('model_version')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance; ``model_kwargs`` other than framework keys become ``params``."""
        excluded = ('model_id', 'use_local', 'streaming')
        params = {key: value for key, value in model_kwargs.items() if key not in excluded}
        return VolcanicEngineTextToImage(
            model_version=model_name,
            access_key=model_credential.get('access_key'),
            secret_key=model_credential.get('secret_key'),
            params=params,
        )

    def check_auth(self):
        """Verify the credentials by issuing one real generation request."""
        self.generate_image('生成一张小猫图片')

    def generate_image(self, prompt: str, negative_prompt: str = None):
        """Generate images for ``prompt`` and return their URLs.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: Accepted for interface compatibility; it is not
                forwarded to the API by this implementation.

        Returns:
            The list of image URLs returned by ``signV4Request``.
        """
        # Request query, filled in as the API documentation specifies.
        query_params = {
            'Action': 'CVProcess',
            'Version': '2022-08-31',
        }
        formatted_query = formatQuery(query_params)
        params = self.params or {}
        # Read the size without mutating self.params: popping it would make
        # every later call silently fall back to the 512*512 default.
        size = params.get('size', '512*512').split('*')
        extra_params = {key: value for key, value in params.items() if key != 'size'}
        body_params = {
            "req_key": req_key_dict[self.model_version],
            "prompt": prompt,
            "model_version": self.model_version,
            "return_url": True,
            "width": int(size[0]),
            "height": int(size[1]),
            **extra_params,
        }
        formatted_body = json.dumps(body_params)
        return signV4Request(self.access_key, self.secret_key, service, formatted_query, formatted_body)

    def is_cache_model(self):
        """Return ``False``."""
        return False
setting.models_provider.impl.volcanic_engine_model_provider.credential.tti import VolcanicEngineTTIModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.credential.tts import VolcanicEngineTTSModelCredential +from setting.models_provider.impl.volcanic_engine_model_provider.model.image import VolcanicEngineImage from setting.models_provider.impl.volcanic_engine_model_provider.model.llm import VolcanicEngineChatModel from setting.models_provider.impl.volcanic_engine_model_provider.credential.stt import VolcanicEngineSTTModelCredential from setting.models_provider.impl.volcanic_engine_model_provider.model.stt import VolcanicEngineSpeechToText +from setting.models_provider.impl.volcanic_engine_model_provider.model.tti import VolcanicEngineTextToImage from setting.models_provider.impl.volcanic_engine_model_provider.model.tts import VolcanicEngineTextToSpeech from smartdoc.conf import PROJECT_DIR @@ -25,6 +30,8 @@ from smartdoc.conf import PROJECT_DIR volcanic_engine_llm_model_credential = OpenAILLMModelCredential() volcanic_engine_stt_model_credential = VolcanicEngineSTTModelCredential() volcanic_engine_tts_model_credential = VolcanicEngineTTSModelCredential() +volcanic_engine_image_model_credential = VolcanicEngineImageModelCredential() +volcanic_engine_tti_model_credential = VolcanicEngineTTIModelCredential() model_info_list = [ ModelInfo('ep-xxxxxxxxxx-yyyy', @@ -32,6 +39,11 @@ model_info_list = [ ModelTypeConst.LLM, volcanic_engine_llm_model_credential, VolcanicEngineChatModel ), + ModelInfo('ep-xxxxxxxxxx-yyyy', + '用户前往火山方舟的模型推理页面创建推理接入点,这里需要输入ep-xxxxxxxxxx-yyyy进行调用', + ModelTypeConst.IMAGE, + volcanic_engine_image_model_credential, VolcanicEngineImage + ), ModelInfo('asr', '', ModelTypeConst.STT, @@ -42,6 +54,31 @@ model_info_list = [ ModelTypeConst.TTS, volcanic_engine_tts_model_credential, VolcanicEngineTextToSpeech ), + ModelInfo('general_v2.0', + '通用2.0-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, 
VolcanicEngineTextToImage + ), + ModelInfo('general_v2.0_L', + '通用2.0Pro-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('general_v1.4', + '通用1.4-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('anime_v1.3', + '动漫1.3.0-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), + ModelInfo('anime_v1.3.1', + '动漫1.3.1-文生图', + ModelTypeConst.TTI, + volcanic_engine_tti_model_credential, VolcanicEngineTextToImage + ), ] open_ai_embedding_credential = OpenAIEmbeddingCredential() @@ -51,8 +88,13 @@ model_info_embedding_list = [ ModelTypeConst.EMBEDDING, open_ai_embedding_credential, OpenAIEmbeddingModel)] -model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - model_info_list[0]).build() +model_info_manage = ( + ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_default_model_info(model_info_list[0]) + .append_default_model_info(model_info_list[1]) + .build() +) class VolcanicEngineModelProvider(IModelProvider): diff --git a/ui/src/workflow/nodes/image-understand/index.vue b/ui/src/workflow/nodes/image-understand/index.vue index ecde49434..40f23847c 100644 --- a/ui/src/workflow/nodes/image-understand/index.vue +++ b/ui/src/workflow/nodes/image-understand/index.vue @@ -25,6 +25,15 @@