mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-27 20:42:52 +00:00
51 lines
1.5 KiB
Python
51 lines
1.5 KiB
Python
# coding=utf-8
|
||
"""
|
||
@project: MaxKB
|
||
@Author:虎
|
||
@file: embedding.py
|
||
@date:2024/10/17 15:29
|
||
@desc:
|
||
"""
|
||
|
||
import base64
|
||
import json
|
||
from typing import Dict, Optional
|
||
|
||
import numpy as np
|
||
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||
from numpy import ndarray
|
||
|
||
from models_provider.base_model_provider import MaxKBBaseModel
|
||
|
||
|
||
class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
    """Embedding model combining MaxKBBaseModel with SparkLLMTextEmbeddings."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an XFEmbedding from a credential mapping.

        Only ``model_credential`` is read; ``model_type``, ``model_name`` and
        ``model_kwargs`` are accepted but not used here.

        :param model_credential: mapping looked up for ``base_url``,
            ``spark_app_id``, ``spark_api_key`` and ``spark_api_secret``.
            A missing key is forwarded to the constructor as ``None``.
        :return: a new ``XFEmbedding`` instance.
        """
        credential_keys = ('base_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret')
        init_kwargs = {key: model_credential.get(key) for key in credential_keys}
        return XFEmbedding(**init_kwargs)

    @staticmethod
    def _parser_message(
            message: str,
    ) -> Optional[ndarray]:
        """Decode a JSON response string into a float32 vector.

        :param message: JSON text containing ``header.code`` and, on success,
            a base64 string at ``payload.feature.text``.
        :return: the decoded values as a little-endian float32 array, cut to
            the first 2560 entries when longer.
        :raises Exception: when ``header.code`` is not 0.
        """
        response = json.loads(message)
        status = response["header"]["code"]
        if status != 0:
            # iFlytek's QPS limit surfaces as an error here, which is why the
            # iFlytek embedding model is not recommended.
            raise Exception(f"Request error: {status}, {response}")

        raw_bytes = base64.b64decode(response["payload"]["feature"]["text"])
        # The payload is read as little-endian 32-bit floats.
        little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
        vector = np.frombuffer(raw_bytes, dtype=little_endian_f32)
        # NOTE(review): 2560 is presumably the model's embedding width - confirm.
        return vector[:2560] if len(vector) > 2560 else vector
|