feat: 讯飞图片模型

This commit is contained in:
CaptainB 2024-11-04 12:31:36 +08:00 committed by 刘瑞斌
parent f318f2da40
commit ddad340534
9 changed files with 240 additions and 4 deletions

View File

@ -149,6 +149,7 @@ class ModelTypeConst(Enum):
EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'}
STT = {'code': 'STT', 'message': '语音识别'}
TTS = {'code': 'TTS', 'message': '语音合成'}
IMAGE = {'code': 'IMAGE', 'message': '图片理解'}
RERANKER = {'code': 'RERANKER', 'message': '重排模型'}

View File

@ -0,0 +1,14 @@
# coding=utf-8
from abc import abstractmethod
from pydantic import BaseModel
class BaseImage(BaseModel):
@abstractmethod
def check_auth(self):
pass
@abstractmethod
def image_understand(self, image_file, text):
pass

View File

@ -0,0 +1,46 @@
# coding=utf-8
from typing import Dict
from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
class XunFeiImageModelCredential(BaseForm, BaseModelCredential):
spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://spark-api.cn-huabei-1.xf-yun.com/v2.1/image')
spark_app_id = forms.TextInputField('APP ID', required=True)
spark_api_key = forms.PasswordInputField("API Key", required=True)
spark_api_secret = forms.PasswordInputField('API Secret', required=True)
def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
raise_exception=False):
model_type_list = provider.get_model_type_list()
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
if key not in model_credential:
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
else:
return False
try:
model = provider.get_model(model_type, model_name, model_credential)
model.check_auth()
except Exception as e:
if isinstance(e, AppApiException):
raise e
if raise_exception:
raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
else:
return False
return True
def encryption_dict(self, model: Dict[str, object]):
return {**model, 'spark_api_secret': super().encryption(model.get('spark_api_secret', ''))}
def get_model_params_setting_form(self, model_name):
pass

View File

@ -0,0 +1,170 @@
# coding=utf-8
import asyncio
import base64
import datetime
import hashlib
import hmac
import json
import os
import ssl
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode
from urllib.parse import urlparse
import websockets
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_image import BaseImage
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
class XFSparkImage(MaxKBBaseModel, BaseImage):
spark_app_id: str
spark_api_key: str
spark_api_secret: str
spark_api_url: str
params: dict
# 初始化
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.spark_api_url = kwargs.get('spark_api_url')
self.spark_app_id = kwargs.get('spark_app_id')
self.spark_api_key = kwargs.get('spark_api_key')
self.spark_api_secret = kwargs.get('spark_api_secret')
self.params = kwargs.get('params')
@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
optional_params = {'params': {}}
for key, value in model_kwargs.items():
if key not in ['model_id', 'use_local', 'streaming']:
optional_params['params'][key] = value
return XFSparkImage(
spark_app_id=model_credential.get('spark_app_id'),
spark_api_key=model_credential.get('spark_api_key'),
spark_api_secret=model_credential.get('spark_api_secret'),
spark_api_url=model_credential.get('spark_api_url'),
**optional_params
)
def create_url(self):
url = self.spark_api_url
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"
signature_origin += "date: " + date + "\n"
signature_origin += "GET " + "/v2.1/image " + "HTTP/1.1"
# 进行hmac-sha256进行加密
signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
digestmod=hashlib.sha256).digest()
signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8')
authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % (
self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha)
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
# 将请求的鉴权参数组合为字典
v = {
"authorization": authorization,
"date": date,
"host": host
}
# 拼接鉴权参数生成url
url = url + '?' + urlencode(v)
# print("date: ",date)
# print("v: ",v)
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释比对相同参数时生成的url与自己代码生成的url是否一致
# print('websocket url :', url)
return url
def check_auth(self):
cwd = os.path.dirname(os.path.abspath(__file__))
with open(f'{cwd}/img_1.png', 'rb') as f:
self.image_understand(f,"一句话概述这个图片")
def image_understand(self, image_file, question):
async def handle():
async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws:
# 发送 full client request
await self.send(ws, image_file, question)
return await self.handle_message(ws)
return asyncio.run(handle())
# 收到websocket消息的处理
@staticmethod
async def handle_message(ws):
# print(message)
answer = ''
while True:
res = await ws.recv()
data = json.loads(res)
code = data['header']['code']
if code != 0:
return f'请求错误: {code}, {data}'
else:
choices = data["payload"]["choices"]
status = choices["status"]
content = choices["text"][0]["content"]
# print(content, end="")
answer += content
# print(1)
if status == 2:
break
return answer
async def send(self, ws, image_file, question):
text = [
{"role": "user", "content": str(base64.b64encode(image_file.read()), 'utf-8'), "content_type": "image"},
{"role": "user", "content": question}
]
data = {
"header": {
"app_id": self.spark_app_id
},
"parameter": {
"chat": {
"domain": "image",
"temperature": 0.5,
"top_k": 4,
"max_tokens": 2028,
"auditing": "default"
}
},
"payload": {
"message": {
"text": text
}
}
}
d = json.dumps(data)
await ws.send(d)
def is_cache_model(self):
return False
@staticmethod
def get_len(text):
length = 0
for content in text:
temp = content["content"]
leng = len(temp)
length += leng
return length
def check_len(self, text):
print("text-content-tokens:", self.get_len(text[1:]))
while (self.get_len(text[1:]) > 8000):
del text[1]
return text

Binary file not shown.

After

Width:  |  Height:  |  Size: 354 KiB

View File

@ -10,7 +10,7 @@ import hmac
import json
import logging
import os
from datetime import datetime
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse
import ssl
@ -63,7 +63,7 @@ class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText):
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.utcnow().strftime(gmt_format)
date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"

View File

@ -12,7 +12,7 @@ import hmac
import json
import logging
import os
from datetime import datetime
from datetime import datetime, UTC
from typing import Dict
from urllib.parse import urlencode, urlparse
import ssl
@ -67,7 +67,7 @@ class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
host = urlparse(url).hostname
# 生成RFC1123格式的时间戳
gmt_format = '%a, %d %b %Y %H:%M:%S GMT'
date = datetime.utcnow().strftime(gmt_format)
date = datetime.now(UTC).strftime(gmt_format)
# 拼接字符串
signature_origin = "host: " + host + "\n"

View File

@ -13,10 +13,12 @@ from common.util.file_util import get_file_content
from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \
ModelInfoManage
from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential
from setting.models_provider.impl.xf_model_provider.credential.image import XunFeiImageModelCredential
from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential
from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential
from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential
from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding
from setting.models_provider.impl.xf_model_provider.model.image import XFSparkImage
from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM
from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText
from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech
@ -26,6 +28,7 @@ ssl._create_default_https_context = ssl.create_default_context()
qwen_model_credential = XunFeiLLMModelCredential()
stt_model_credential = XunFeiSTTModelCredential()
image_model_credential = XunFeiImageModelCredential()
tts_model_credential = XunFeiTTSModelCredential()
embedding_model_credential = XFEmbeddingCredential()
model_info_list = [
@ -34,6 +37,7 @@ model_info_list = [
ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM),
ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText),
ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech),
ModelInfo('image', '', ModelTypeConst.IMAGE, image_model_credential, XFSparkImage),
ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding)
]

View File

@ -132,6 +132,7 @@
<el-option label="重排模型" value="RERANKER" />
<el-option label="语音识别" value="STT" />
<el-option label="语音合成" value="TTS" />
<el-option label="图片理解" value="IMAGE" />
</el-select>
</div>
</div>