diff --git a/apps/common/chunk/__init__.py b/apps/common/chunk/__init__.py new file mode 100644 index 000000000..a4babde76 --- /dev/null +++ b/apps/common/chunk/__init__.py @@ -0,0 +1,18 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/7/23 17:03 + @desc: +""" +from common.chunk.impl.mark_chunk_handle import MarkChunkHandle + +handles = [MarkChunkHandle()] + + +def text_to_chunk(text: str): + chunk_list = [text] + for handle in handles: + chunk_list = handle.handle(chunk_list) + return chunk_list diff --git a/apps/common/chunk/i_chunk_handle.py b/apps/common/chunk/i_chunk_handle.py new file mode 100644 index 000000000..d53575d11 --- /dev/null +++ b/apps/common/chunk/i_chunk_handle.py @@ -0,0 +1,16 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: i_chunk_handle.py + @date:2024/7/23 16:51 + @desc: +""" +from abc import ABC, abstractmethod +from typing import List + + +class IChunkHandle(ABC): + @abstractmethod + def handle(self, chunk_list: List[str]): + pass diff --git a/apps/common/chunk/impl/mark_chunk_handle.py b/apps/common/chunk/impl/mark_chunk_handle.py new file mode 100644 index 000000000..f86290a1f --- /dev/null +++ b/apps/common/chunk/impl/mark_chunk_handle.py @@ -0,0 +1,24 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: mark_chunk_handle.py + @date:2024/7/23 16:52 + @desc: +""" +import re +from typing import List + +from common.chunk.i_chunk_handle import IChunkHandle + +split_chunk_pattern = "!|。|\n|;|;" + + +class MarkChunkHandle(IChunkHandle): + def handle(self, chunk_list: List[str]): + result = [] + for chunk in chunk_list: + base_chunk = re.split(split_chunk_pattern, chunk) + base_chunk = [chunk.strip() for chunk in base_chunk if len(chunk.strip()) > 0] + result = [*result, *base_chunk] + return result diff --git a/apps/common/sql/list_embedding_text.sql b/apps/common/sql/list_embedding_text.sql index 74f3b224b..ac0dc7b31 100644 --- a/apps/common/sql/list_embedding_text.sql +++ b/apps/common/sql/list_embedding_text.sql @@ -19,9 +19,7 @@ SELECT paragraph."id" AS paragraph_id, paragraph.dataset_id AS dataset_id, 1 AS source_type, - concat_ws(' -',concat_ws(' -',paragraph.title,paragraph."content"),paragraph.title) AS "text", + concat_ws(E'\n',paragraph.title,paragraph."content") AS "text", paragraph.is_active AS is_active FROM paragraph paragraph diff --git a/apps/embedding/migrations/0003_alter_embedding_unique_together.py b/apps/embedding/migrations/0003_alter_embedding_unique_together.py new file mode 100644 index 000000000..9cb45061b --- /dev/null +++ b/apps/embedding/migrations/0003_alter_embedding_unique_together.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.14 on 2024-07-23 18:14 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0002_embedding_search_vector'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='embedding', + unique_together=set(), + ), + ] diff --git a/apps/embedding/models/embedding.py b/apps/embedding/models/embedding.py index 24c78f41f..5f954e36b 100644 --- a/apps/embedding/models/embedding.py +++ b/apps/embedding/models/embedding.py @@ -50,4 +50,3 @@ class Embedding(models.Model): class Meta: db_table = "embedding" - unique_together = ['source_id', 'source_type'] diff --git a/apps/embedding/vector/base_vector.py b/apps/embedding/vector/base_vector.py index 281ca12f2..15a4449b5 100644 --- a/apps/embedding/vector/base_vector.py +++ b/apps/embedding/vector/base_vector.py @@ -8,16 +8,31 @@ """ import threading from abc import ABC, abstractmethod +from functools import reduce from typing import List, Dict from langchain_core.embeddings import Embeddings +from common.chunk import text_to_chunk from common.util.common import sub_array from embedding.models import SourceType, SearchMode lock = threading.Lock() +def chunk_data(data: Dict): + if str(data.get('source_type')) == SourceType.PARAGRAPH.value: + text = data.get('text') + chunk_list = text_to_chunk(text) + return [{**data, 'text': chunk} for chunk in chunk_list] + return [data] + + +def chunk_data_list(data_list: List[Dict]): + result = [chunk_data(data) for data in data_list] + return reduce(lambda x, y: [*x, *y], result, []) + + class BaseVectorStore(ABC): vector_exists = False @@ -64,7 +79,12 @@ class BaseVectorStore(ABC): :return: bool """ self.save_pre_handler() - self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding) + data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'dataset_id': dataset_id, + 'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text} + chunk_list = chunk_data(data) + result = sub_array(chunk_list) + for child_array in result: + self._batch_save(child_array, embedding) def batch_save(self, data_list: List[Dict], embedding: Embeddings): # 获取锁 @@ -77,7 +97,8 @@ class BaseVectorStore(ABC): :return: bool """ self.save_pre_handler() - result = sub_array(data_list) + chunk_list = chunk_data_list(data_list) + result = sub_array(chunk_list) for child_array in result: self._batch_save(child_array, embedding) finally: diff --git a/apps/setting/migrations/0006_alter_model_status.py b/apps/setting/migrations/0006_alter_model_status.py new file mode 100644 index 000000000..209f57c94 --- /dev/null +++ b/apps/setting/migrations/0006_alter_model_status.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-07-23 18:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0005_model_permission_type'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='status', + field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中'), ('PAUSE_DOWNLOAD', '暂停下载')], default='SUCCESS', max_length=20, verbose_name='设置类型'), + ), + ]