diff --git a/apps/common/chunk/impl/mark_chunk_handle.py b/apps/common/chunk/impl/mark_chunk_handle.py index 9e5a66c0e..4f8623c9c 100644 --- a/apps/common/chunk/impl/mark_chunk_handle.py +++ b/apps/common/chunk/impl/mark_chunk_handle.py @@ -11,27 +11,25 @@ from typing import List from common.chunk.i_chunk_handle import IChunkHandle -split_chunk_pattern = "!|。|\n|;|;" -min_chunk_len = 20 +max_chunk_len = 256 +split_chunk_pattern = r'.{1,%d}[。| |\\.|!|;|;|!|\n]' % max_chunk_len +max_chunk_pattern = r'.{1,%d}' % max_chunk_len class MarkChunkHandle(IChunkHandle): def handle(self, chunk_list: List[str]): result = [] for chunk in chunk_list: - base_chunk = re.split(split_chunk_pattern, chunk) - base_chunk = [chunk.strip() for chunk in base_chunk if len(chunk.strip()) > 0] - result_chunk = [] - for c in base_chunk: - if len(result_chunk) == 0: - result_chunk.append(c) - else: - if len(result_chunk[-1]) < min_chunk_len: - result_chunk[-1] = result_chunk[-1] + c + chunk_result = re.findall(split_chunk_pattern, chunk, flags=re.DOTALL) + for c_r in chunk_result: + result.append(c_r) + other_chunk_list = re.split(split_chunk_pattern, chunk, flags=re.DOTALL) + for other_chunk in other_chunk_list: + if len(other_chunk) > 0: + if len(other_chunk) < max_chunk_len: + result.append(other_chunk) else: - if len(c) < min_chunk_len: - result_chunk[-1] = result_chunk[-1] + c - else: - result_chunk.append(c) - result = [*result, *result_chunk] + max_chunk_list = re.findall(max_chunk_pattern, other_chunk, flags=re.DOTALL) + for m_c in max_chunk_list: + result.append(m_c) return result diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py index fffc7e68d..f3fd75a80 100644 --- a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -11,10 +11,13 @@ import os from common.util.file_util import get_file_content from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ ModelInfoManage +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \ + AliyunBaiLianEmbeddingCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \ AliyunBaiLianRerankerCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech @@ -23,6 +26,7 @@ from smartdoc.conf import PROJECT_DIR aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential() aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential() aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential() +aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential() model_info_list = [ModelInfo('gte-rerank', '阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。', @@ -33,10 +37,15 @@ model_info_list = [ModelInfo('gte-rerank', ModelInfo('cosyvoice-v1', 'CosyVoice基于新一代生成式语音大模型,能根据上下文预测情绪、语调、韵律等,具有更好的拟人效果', ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech), + ModelInfo('text-embedding-v1', + '通用文本向量,是通义实验室基于LLM底座的多语言文本统一向量模型,面向全球多个主流语种,提供高水准的向量服务,帮助开发者将文本数据快速转换为高质量的向量数据。', + ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential, + AliyunBaiLianEmbedding), ] model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( - model_info_list[1]).append_default_model_info(model_info_list[2]).build() + model_info_list[1]).append_default_model_info(model_info_list[2]).append_default_model_info( + model_info_list[3]).build() class AliyunBaiLianModelProvider(IModelProvider): diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py new file mode 100644 index 000000000..ba53555f0 --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/16 17:01 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding + + +class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['dashscope_api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git a/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py new file mode 100644 index 000000000..e209e770b --- /dev/null +++ b/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/16 16:34 + @desc: +""" +from typing import Dict, List + +from langchain_community.embeddings import DashScopeEmbeddings +from langchain_community.embeddings.dashscope import embed_with_retry + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return AliyunBaiLianEmbedding( + model=model_name, + dashscope_api_key=model_credential.get('dashscope_api_key') + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to DashScope's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + embeddings = embed_with_retry( + self, input=texts, text_type="document", model=self.model + ) + embedding_list = [item["embedding"] for item in embeddings] + return embedding_list + + def embed_query(self, text: str) -> List[float]: + """Call out to DashScope's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. + """ + embedding = embed_with_retry( + self, input=[text], text_type="document", model=self.model + )[0]["embedding"] + return embedding