mirror of
https://github.com/1Panel-dev/MaxKB.git
synced 2025-12-25 17:22:55 +00:00
feat: 细分段落chunk增加召回命中率 (#841)
This commit is contained in:
parent
203c3e5cde
commit
53434f9d24
|
|
@ -0,0 +1,18 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: __init__.py
|
||||
@date:2024/7/23 17:03
|
||||
@desc:
|
||||
"""
|
||||
from common.chunk.impl.mark_chunk_handle import MarkChunkHandle
|
||||
|
||||
handles = [MarkChunkHandle()]
|
||||
|
||||
|
||||
def text_to_chunk(text: str) -> list:
    """Split ``text`` into chunks by passing it through every registered handler.

    The text starts out as a one-element chunk list; each entry of the
    module-level ``handles`` list receives the output of the previous one.

    :param text: raw text to split; ``None`` and ``""`` are tolerated
    :return: the chunk list returned by the last handler, or ``[]`` when
        there is no text to split
    """
    # Callers may hand over a missing text (e.g. the result of ``dict.get``);
    # handlers built on ``re.split`` reject ``None`` with a TypeError, so
    # short-circuit here: no text means no chunks.
    if not text:
        return []
    chunk_list = [text]
    for handle in handles:
        chunk_list = handle.handle(chunk_list)
    return chunk_list
|
||||
|
|
@ -0,0 +1,16 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: i_chunk_handle.py
|
||||
@date:2024/7/23 16:51
|
||||
@desc:
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
|
||||
class IChunkHandle(ABC):
    """Abstract base for objects that transform a list of text chunks."""

    @abstractmethod
    def handle(self, chunk_list: List[str]):
        """Process ``chunk_list``; concrete subclasses supply the behaviour.

        :param chunk_list: the chunks to process
        :return: defined by the implementing subclass
        """
|
||||
|
|
@ -0,0 +1,24 @@
|
|||
# coding=utf-8
|
||||
"""
|
||||
@project: MaxKB
|
||||
@Author:虎
|
||||
@file: mark_chunk_handle.py
|
||||
@date:2024/7/23 16:52
|
||||
@desc:
|
||||
"""
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from common.chunk.i_chunk_handle import IChunkHandle
|
||||
|
||||
# Delimiters at which a chunk is cut: exclamation mark, ideographic full stop,
# line break and semicolon. Both the ASCII and the full-width forms of "!" and
# ";" are listed so that CJK punctuation splits the same way as ASCII
# punctuation (the previous pattern listed ";" twice).
split_chunk_pattern = "!|！|。|\n|;|；"


class MarkChunkHandle(IChunkHandle):
    """Chunk handler that cuts text at the delimiters in ``split_chunk_pattern``."""

    def handle(self, chunk_list: List[str]) -> List[str]:
        """Split every chunk at the delimiters and drop blank pieces.

        :param chunk_list: chunks to split further
        :return: a new flat list with the stripped, non-empty pieces of every
            input chunk, in their original order
        """
        result = []
        for chunk in chunk_list:
            pieces = (piece.strip() for piece in re.split(split_chunk_pattern, chunk))
            # extend() appends in place instead of rebuilding the accumulated
            # list on every iteration.
            result.extend(piece for piece in pieces if piece)
        return result
|
||||
|
|
@ -19,9 +19,7 @@ SELECT
|
|||
paragraph."id" AS paragraph_id,
|
||||
paragraph.dataset_id AS dataset_id,
|
||||
1 AS source_type,
|
||||
concat_ws('
|
||||
',concat_ws('
|
||||
',paragraph.title,paragraph."content"),paragraph.title) AS "text",
|
||||
concat_ws(E'\n',paragraph.title,paragraph."content") AS "text",
|
||||
paragraph.is_active AS is_active
|
||||
FROM
|
||||
paragraph paragraph
|
||||
|
|
|
|||
|
|
@ -0,0 +1,17 @@
|
|||
# Generated by Django 4.2.14 on 2024-07-23 18:14
|
||||
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Set ``unique_together`` of the ``embedding`` model to an empty set."""

    dependencies = [("embedding", "0002_embedding_search_vector")]

    operations = [
        # An empty set leaves the model without any unique_together group.
        migrations.AlterUniqueTogether(name="embedding", unique_together=set()),
    ]
|
||||
|
|
@ -50,4 +50,3 @@ class Embedding(models.Model):
|
|||
|
||||
class Meta:
|
||||
db_table = "embedding"
|
||||
unique_together = ['source_id', 'source_type']
|
||||
|
|
|
|||
|
|
@ -8,16 +8,31 @@
|
|||
"""
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import reduce
|
||||
from typing import List, Dict
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from common.chunk import text_to_chunk
|
||||
from common.util.common import sub_array
|
||||
from embedding.models import SourceType, SearchMode
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
|
||||
def chunk_data(data: Dict):
    """Expand a paragraph record into one record per text chunk.

    :param data: record to expand; its ``source_type`` and ``text`` entries
        are read
    :return: for a paragraph record, one copy of ``data`` per chunk of its
        text with ``text`` replaced by that chunk; for any other record, a
        one-element list holding ``data`` itself
    """
    # Only paragraph sources are chunked; everything else passes through.
    if str(data.get('source_type')) != SourceType.PARAGRAPH.value:
        return [data]
    pieces = text_to_chunk(data.get('text'))
    return [{**data, 'text': piece} for piece in pieces]
|
||||
|
||||
|
||||
def chunk_data_list(data_list: List[Dict]) -> List[Dict]:
    """Apply ``chunk_data`` to every record and flatten the results.

    :param data_list: records to expand
    :return: one flat list holding the records produced for each input
        record, in input order; ``[]`` for an empty input
    """
    # A single nested comprehension flattens in linear time; folding with
    # ``[*x, *y]`` re-copied the accumulated list for every record.
    return [chunk for data in data_list for chunk in chunk_data(data)]
|
||||
|
||||
|
||||
class BaseVectorStore(ABC):
|
||||
vector_exists = False
|
||||
|
||||
|
|
@ -64,7 +79,12 @@ class BaseVectorStore(ABC):
|
|||
:return: bool
|
||||
"""
|
||||
self.save_pre_handler()
|
||||
self._save(text, source_type, dataset_id, document_id, paragraph_id, source_id, is_active, embedding)
|
||||
data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'dataset_id': dataset_id,
|
||||
'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text}
|
||||
chunk_list = chunk_data(data)
|
||||
result = sub_array(chunk_list)
|
||||
for child_array in result:
|
||||
self._batch_save(child_array, embedding)
|
||||
|
||||
def batch_save(self, data_list: List[Dict], embedding: Embeddings):
|
||||
# 获取锁
|
||||
|
|
@ -77,7 +97,8 @@ class BaseVectorStore(ABC):
|
|||
:return: bool
|
||||
"""
|
||||
self.save_pre_handler()
|
||||
result = sub_array(data_list)
|
||||
chunk_list = chunk_data_list(data_list)
|
||||
result = sub_array(chunk_list)
|
||||
for child_array in result:
|
||||
self._batch_save(child_array, embedding)
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,18 @@
|
|||
# Generated by Django 4.2.14 on 2024-07-23 18:14
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
    """Alter the ``status`` field of the ``model`` model."""

    dependencies = [("setting", "0005_model_permission_type")]

    operations = [
        migrations.AlterField(
            model_name="model",
            name="status",
            field=models.CharField(
                choices=[
                    ("SUCCESS", "成功"),  # success
                    ("ERROR", "失败"),  # failure
                    ("DOWNLOAD", "下载中"),  # downloading
                    ("PAUSE_DOWNLOAD", "暂停下载"),  # download paused
                ],
                default="SUCCESS",
                max_length=20,
                verbose_name="设置类型",  # "setting type"
            ),
        ),
    ]
|
||||
Loading…
Reference in New Issue