diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py
new file mode 100644
index 000000000..e2cbbb194
--- /dev/null
+++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/image.py
@@ -0,0 +1,79 @@
+# coding=utf-8
+import base64
+import os
+from typing import Dict
+
+from langchain_core.messages import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm, TooltipLabel
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class XinferenceImageModelParams(BaseForm):
+    """Tunable generation parameters (temperature, max_tokens) for Xinference image models."""
+
+    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
+                                    required=True, default_value=0.7,
+                                    _min=0.1,
+                                    _max=1.0,
+                                    _step=0.01,
+                                    precision=2)
+
+    max_tokens = forms.SliderField(
+        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
+        required=True, default_value=800,
+        _min=1,
+        _max=100000,
+        _step=1,
+        precision=0)
+
+
+class XinferenceImageModelCredential(BaseForm, BaseModelCredential):
+    """Credential form (api_base, api_key) and validator for Xinference image models."""
+
+    api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        """Validate the credential by streaming a short chat request through the model.
+
+        Returns True when the request succeeds. On a missing field or a failed
+        request, raises AppApiException if ``raise_exception`` is set and
+        returns False otherwise. An unsupported ``model_type`` always raises.
+        """
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
+            # Drain the stream so connection/auth errors surface here; the
+            # chunks themselves are not needed.
+            for _ in res:
+                pass
+        except Exception as e:
+            # AppApiException is already in the caller-facing form; pass it through.
+            if isinstance(e, AppApiException):
+                raise
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        """Return the parameter form; the same form is used for every ``model_name``."""
+        return XinferenceImageModelParams()
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py
new file mode 100644
index 000000000..eba50d022
--- /dev/null
+++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/tti.py
@@ -0,0 +1,94 @@
+# coding=utf-8
+import base64
+import os
+from typing import Dict
+
+from langchain_core.messages import HumanMessage
+
+from common import forms
+from common.exception.app_exception import AppApiException
+from common.forms import BaseForm, TooltipLabel
+from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
+
+
+class XinferenceTTIModelParams(BaseForm):
+    """Image-generation parameters (size, quality, n) for Xinference text-to-image models."""
+
+    size = forms.SingleSelect(
+        TooltipLabel('图片尺寸', '指定生成图片的尺寸, 如: 1024x1024'),
+        required=True,
+        default_value='1024x1024',
+        option_list=[
+            {'value': '1024x1024', 'label': '1024x1024'},
+            {'value': '1024x1792', 'label': '1024x1792'},
+            {'value': '1792x1024', 'label': '1792x1024'},
+        ],
+        text_field='label',
+        value_field='value'
+    )
+
+    quality = forms.SingleSelect(
+        TooltipLabel('图片质量', ''),
+        required=True,
+        default_value='standard',
+        option_list=[
+            {'value': 'standard', 'label': 'standard'},
+            {'value': 'hd', 'label': 'hd'},
+        ],
+        text_field='label',
+        value_field='value'
+    )
+
+    n = forms.SliderField(
+        TooltipLabel('图片数量', '指定生成图片的数量'),
+        required=True, default_value=1,
+        _min=1,
+        _max=10,
+        _step=1,
+        precision=0)
+
+
+class XinferenceTextToImageModelCredential(BaseForm, BaseModelCredential):
+    """Credential form (api_base, api_key) and validator for Xinference text-to-image models."""
+
+    api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
+
+    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
+                 raise_exception=False):
+        """Validate the credential by calling the model's ``check_auth``.
+
+        Returns True when the check succeeds. On a missing field or a failed
+        check, raises AppApiException if ``raise_exception`` is set and returns
+        False otherwise. An unsupported ``model_type`` always raises.
+        """
+        model_type_list = provider.get_model_type_list()
+        if not any(mt.get('value') == model_type for mt in model_type_list):
+            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
+
+        for key in ['api_base', 'api_key']:
+            if key not in model_credential:
+                if raise_exception:
+                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
+                else:
+                    return False
+        try:
+            model = provider.get_model(model_type, model_name, model_credential)
+            # Only success/failure matters; any exception is handled below.
+            model.check_auth()
+        except Exception as e:
+            # AppApiException is already in the caller-facing form; pass it through.
+            if isinstance(e, AppApiException):
+                raise
+            if raise_exception:
+                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
+            else:
+                return False
+        return True
+
+    def encryption_dict(self, model: Dict[str, object]):
+        """Return a copy of ``model`` whose ``api_key`` went through the inherited ``encryption``."""
+        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}
+
+    def get_model_params_setting_form(self, model_name):
+        """Return the parameter form; the same form is used for every ``model_name``."""
+        return XinferenceTTIModelParams()
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/image.py b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
new file mode 100644
index 000000000..1b696b8cf
--- /dev/null
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/image.py
@@ -0,0 +1,29 @@
+from typing import Dict
+
+from langchain_openai.chat_models import ChatOpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+
+
+def custom_get_token_ids(text: str):
+    """Encode ``text`` with the tokenizer from TokenizerManage and return the result."""
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class XinferenceImage(MaxKBBaseModel, ChatOpenAI):
+    """Xinference image chat model, built on the ChatOpenAI client."""
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Build a streaming instance from the credential plus the filtered optional kwargs."""
+        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
+        return XinferenceImage(
+            model_name=model_name,
+            openai_api_base=model_credential.get('api_base'),
+            openai_api_key=model_credential.get('api_key'),
+            # stream_options={"include_usage": True},
+            streaming=True,
+            **optional_params,
+        )
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py b/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py
new file mode 100644
index 000000000..ee5b655f4
--- /dev/null
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/tti.py
@@ -0,0 +1,78 @@
+import base64
+from typing import Dict
+
+from openai import OpenAI
+
+from common.config.tokenizer_manage_config import TokenizerManage
+from common.util.common import bytes_to_uploaded_file
+from dataset.serializers.file_serializers import FileSerializer
+from setting.models_provider.base_model_provider import MaxKBBaseModel
+from setting.models_provider.impl.base_tti import BaseTextToImage
+
+
+def custom_get_token_ids(text: str):
+    """Encode ``text`` with the tokenizer from TokenizerManage and return the result."""
+    tokenizer = TokenizerManage.get_tokenizer()
+    return tokenizer.encode(text)
+
+
+class XinferenceTextToImage(MaxKBBaseModel, BaseTextToImage):
+    """Xinference text-to-image model, called through the OpenAI client."""
+
+    api_base: str
+    api_key: str
+    model: str
+    params: dict
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.api_key = kwargs.get('api_key')
+        self.api_base = kwargs.get('api_base')
+        self.model = kwargs.get('model')
+        self.params = kwargs.get('params')
+
+    @staticmethod
+    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
+        """Build an instance; ``model_kwargs`` override the default size/quality/n."""
+        optional_params = {'params': {'size': '1024x1024', 'quality': 'standard', 'n': 1}}
+        for key, value in model_kwargs.items():
+            if key not in ['model_id', 'use_local', 'streaming']:
+                optional_params['params'][key] = value
+        return XinferenceTextToImage(
+            model=model_name,
+            api_base=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
+            **optional_params,
+        )
+
+    def is_cache_model(self):
+        """Always return False."""
+        return False
+
+    def check_auth(self):
+        """List models with the stored endpoint and key; request errors propagate to the caller."""
+        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
+        chat.models.with_raw_response.list()
+
+    def generate_image(self, prompt: str, negative_prompt: str = None):
+        """Generate images for ``prompt``, upload each one, and return the list of URLs.
+
+        ``negative_prompt`` is accepted but is not forwarded to the
+        image-generation request.
+        """
+        chat = OpenAI(api_key=self.api_key, base_url=self.api_base)
+        # params is None when the instance was constructed without a 'params' kwarg.
+        params = self.params or {}
+        res = chat.images.generate(model=self.model, prompt=prompt, response_format='b64_json', **params)
+        file_urls = []
+        # Temporary files
+        for img in res.data:
+            file = bytes_to_uploaded_file(base64.b64decode(img.b64_json), 'file_name.jpg')
+            meta = {
+                'debug': True,
+            }
+            file_url = FileSerializer(data={'file': file, 'meta': meta}).upload()
+            # NOTE(review): hard-coded host - confirm this is intended outside local development.
+            file_urls.append(f'http://localhost:8080{file_url}')
+
+        return file_urls
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
index 0da07f6d3..4c7b9c0c7 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py
@@ -9,20 +9,26 @@ from setting.models_provider.base_model_provider import IModelProvider, ModelPro
     ModelInfoManage
 from setting.models_provider.impl.xinference_model_provider.credential.embedding import \
     XinferenceEmbeddingModelCredential
+from setting.models_provider.impl.xinference_model_provider.credential.image import XinferenceImageModelCredential
 from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential
 from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential
 from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential
+from setting.models_provider.impl.xinference_model_provider.credential.tti import XinferenceTextToImageModelCredential
 from setting.models_provider.impl.xinference_model_provider.credential.tts import XInferenceTTSModelCredential
 from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding
+from setting.models_provider.impl.xinference_model_provider.model.image import XinferenceImage
 from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel
 from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker
 from setting.models_provider.impl.xinference_model_provider.model.stt import XInferenceSpeechToText
+from setting.models_provider.impl.xinference_model_provider.model.tti import XinferenceTextToImage
 from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech
 from smartdoc.conf import PROJECT_DIR
 
 xinference_llm_model_credential = XinferenceLLMModelCredential()
 xinference_stt_model_credential = XInferenceSTTModelCredential()
 xinference_tts_model_credential = XInferenceTTSModelCredential()
+xinference_image_model_credential = XinferenceImageModelCredential()
+xinference_tti_model_credential = XinferenceTextToImageModelCredential()
 
 model_info_list = [
     ModelInfo(
@@ -296,6 +302,159 @@ voice_model_info = [
     ),
 ]
 
+image_model_info = [
+    ModelInfo(
+        'qwen-vl-chat',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'deepseek-vl-chat',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'yi-vl-chat',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'omnilmm',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'internvl-chat',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'cogvlm2',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'MiniCPM-Llama3-V-2_5',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'GLM-4V',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'MiniCPM-V-2.6',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'internvl2',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'qwen2-vl-instruct',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'llama-3.2-vision',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'llama-3.2-vision-instruct',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+    ModelInfo(
+        'glm-edge-v',
+        '',
+        ModelTypeConst.IMAGE,
+        xinference_image_model_credential,
+        XinferenceImage
+    ),
+]
+
+tti_model_info = [
+    ModelInfo(
+        'sd-turbo',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'sdxl-turbo',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'stable-diffusion-v1.5',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'stable-diffusion-xl-base-1.0',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'sd3-medium',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'FLUX.1-schnell',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+    ModelInfo(
+        'FLUX.1-dev',
+        '',
+        ModelTypeConst.TTI,
+        xinference_tti_model_credential,
+        XinferenceTextToImage
+    ),
+]
+
 xinference_embedding_model_credential = XinferenceEmbeddingModelCredential()
 
 # 生成embedding_model_info列表
@@ -377,6 +536,8 @@ model_info_manage = (ModelInfoManage.builder()
                                                           ModelTypeConst.EMBEDDING,
                                                           xinference_embedding_model_credential, XinferenceEmbedding))
                      .append_model_info_list(rerank_list)
+                     .append_model_info_list(image_model_info)
+                     .append_model_info_list(tti_model_info)
                      .append_default_model_info(rerank_list[0])
                      .build())
 