From 1802b58a74d7af3600850536c8ec98f6f98a4710 Mon Sep 17 00:00:00 2001 From: CaptainB Date: Thu, 10 Oct 2024 13:28:11 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0xinference=E8=AF=AD?= =?UTF-8?q?=E9=9F=B3=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../credential/stt.py | 42 +++++++++++++ .../xinference_model_provider/model/stt.py | 59 ++++++++++++++++++ .../xinference_model_provider/model/tts.py | 60 +++++++++++++++++++ .../xinference_model_provider.py | 19 ++++++ 4 files changed, 180 insertions(+) create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/model/stt.py create mode 100644 apps/setting/models_provider/impl/xinference_model_provider/model/tts.py diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py new file mode 100644 index 000000000..7d19feaef --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py @@ -0,0 +1,42 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceSTTModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/stt.py b/apps/setting/models_provider/impl/xinference_model_provider/model/stt.py new file mode 100644 index 000000000..5e21ca6f9 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/stt.py @@ -0,0 +1,59 @@ +import asyncio +import io +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText): + api_base: str + api_key: str + model: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return XInferenceSpeechToText( + model=model_name, + api_base=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + client = OpenAI( + base_url=self.api_base, + api_key=self.api_key + ) + response_list = client.models.with_raw_response.list() + # print(response_list) + + def speech_to_text(self, audio_file): + client = OpenAI( + base_url=self.api_base, + api_key=self.api_key + ) + audio_data = audio_file.read() + buffer = io.BytesIO(audio_data) + buffer.name = "file.mp3" # this is the important line + res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer) + return res.text + diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py b/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py new file mode 100644 index 000000000..6e6e46aa5 --- /dev/null +++ b/apps/setting/models_provider/impl/xinference_model_provider/model/tts.py @@ -0,0 +1,60 @@ +from typing import Dict + +from openai import OpenAI + +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + + +def custom_get_token_ids(text: str): + tokenizer = TokenizerManage.get_tokenizer() + return tokenizer.encode(text) + + +class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + api_base: str + api_key: str + model: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.api_base = kwargs.get('api_base') + self.model = kwargs.get('model') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return XInferenceTextToSpeech( + model=model_name, + api_base=model_credential.get('api_base'), + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + client = OpenAI( + base_url=self.api_base, + api_key=self.api_key + ) + response_list = client.models.with_raw_response.list() + # print(response_list) + + def text_to_speech(self, text): + client = OpenAI( + base_url=self.api_base, + api_key=self.api_key + ) + # ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女'] + + with client.audio.speech.with_streaming_response.create( + model=self.model, + voice="中文女", + input=text, + ) as response: + return response.read() diff --git a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py index 0fbc3cc32..06c199829 100644 --- a/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py +++ b/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -11,12 +11,17 @@ from setting.models_provider.impl.xinference_model_provider.credential.embedding XinferenceEmbeddingModelCredential from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker +from setting.models_provider.impl.xinference_model_provider.model.stt import XInferenceSpeechToText +from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech from smartdoc.conf import PROJECT_DIR xinference_llm_model_credential = XinferenceLLMModelCredential() +xinference_stt_model_credential = XInferenceSTTModelCredential() + model_info_list = [ ModelInfo( 'code-llama', @@ -270,6 +275,20 @@ model_info_list = [ xinference_llm_model_credential, XinferenceChatModel ), + ModelInfo( + 'CosyVoice-300M-SFT', + 'CosyVoice-300M-SFT是一个小型的语音合成模型。', + ModelTypeConst.TTS, + xinference_stt_model_credential, + XInferenceTextToSpeech + ), + ModelInfo( + 'Belle-whisper-large-v3-zh', + 'Belle Whisper Large V3 是一个中文大型语音识别模型。', + ModelTypeConst.STT, + xinference_stt_model_credential, + XInferenceSpeechToText + ), ] xinference_embedding_model_credential = XinferenceEmbeddingModelCredential()