feat: 细分段落chunk增加召回命中率 (#841)

This commit is contained in:
shaohuzhang1 2024-07-23 18:19:41 +08:00 committed by GitHub
parent 203c3e5cde
commit 53434f9d24
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 117 additions and 6 deletions

View File

@ -0,0 +1,18 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file __init__.py
@date2024/7/23 17:03
@desc:
"""
from common.chunk.impl.mark_chunk_handle import MarkChunkHandle
handles = [MarkChunkHandle()]
def text_to_chunk(text: str):
chunk_list = [text]
for handle in handles:
chunk_list = handle.handle(chunk_list)
return chunk_list

View File

@ -0,0 +1,16 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file i_chunk_handle.py
@date2024/7/23 16:51
@desc:
"""
from abc import ABC, abstractmethod
from typing import List
class IChunkHandle(ABC):
@abstractmethod
def handle(self, chunk_list: List[str]):
pass

View File

@ -0,0 +1,24 @@
# coding=utf-8
"""
@project: MaxKB
@Author
@file mark_chunk_handle.py
@date2024/7/23 16:52
@desc:
"""
import re
from typing import List
from common.chunk.i_chunk_handle import IChunkHandle
split_chunk_pattern = "|。|\n||;"
class MarkChunkHandle(IChunkHandle):
def handle(self, chunk_list: List[str]):
result = []
for chunk in chunk_list:
base_chunk = re.split(split_chunk_pattern, chunk)
base_chunk = [chunk.strip() for chunk in base_chunk if len(chunk.strip()) > 0]
result = [*result, *base_chunk]
return result

View File

@ -19,9 +19,7 @@ SELECT
paragraph."id" AS paragraph_id,
paragraph.dataset_id AS dataset_id,
1 AS source_type,
concat_ws('
',concat_ws('
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
concat_ws(E'\n',paragraph.title,paragraph."content") AS "text",
paragraph.is_active AS is_active
FROM
paragraph paragraph

View File

@ -0,0 +1,17 @@
# Generated by Django 4.2.14 on 2024-07-23 18:14
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('embedding', '0002_embedding_search_vector'),
]
operations = [
migrations.AlterUniqueTogether(
name='embedding',
unique_together=set(),
),
]

View File

@ -50,4 +50,3 @@ class Embedding(models.Model):
class Meta:
db_table = "embedding"
unique_together = ['source_id', 'source_type']

View File

@ -8,16 +8,31 @@
"""
import threading
from abc import ABC, abstractmethod
from functools import reduce
from typing import List, Dict
from langchain_core.embeddings import Embeddings
from common.chunk import text_to_chunk
from common.util.common import sub_array
from embedding.models import SourceType, SearchMode
lock = threading.Lock()
def chunk_data(data: Dict):
if str(data.get('source_type')) == SourceType.PARAGRAPH.value:
text = data.get('text')
chunk_list = text_to_chunk(text)
return [{**data, 'text': chunk} for chunk in chunk_list]
return [data]
def chunk_data_list(data_list: List[Dict]):
result = [chunk_data(data) for data in data_list]
return reduce(lambda x, y: [*x, *y], result, [])
class BaseVectorStore(ABC):
vector_exists = False
@ -64,7 +79,12 @@ class BaseVectorStore(ABC):
:return: bool
"""
self.save_pre_handler()
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'dataset_id': dataset_id,
'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
chunk_list = chunk_data(data)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
# 获取锁
@ -77,7 +97,8 @@ class BaseVectorStore(ABC):
:return: bool
"""
self.save_pre_handler()
result = sub_array(data_list)
chunk_list = chunk_data_list(data_list)
result = sub_array(chunk_list)
for child_array in result:
self._batch_save(child_array, embedding)
finally:

View File

@ -0,0 +1,18 @@
# Generated by Django 4.2.14 on 2024-07-23 18:14
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('setting', '0005_model_permission_type'),
]
operations = [
migrations.AlterField(
model_name='model',
name='status',
field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中'), ('PAUSE_DOWNLOAD', '暂停下载')], default='SUCCESS', max_length=20, verbose_name='设置类型'),
),
]