From e9a05b1255d5eafa3dab729d4e5333023296ce93 Mon Sep 17 00:00:00 2001 From: shaohuzhang1 <80892890+shaohuzhang1@users.noreply.github.com> Date: Fri, 24 May 2024 17:59:02 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8Dqa=E7=9F=A5=E8=AF=86?= =?UTF-8?q?=E5=BA=93=E5=AF=BC=E5=85=A5=E5=A4=B1=E8=B4=A5=E9=94=99=E8=AF=AF?= =?UTF-8?q?=20(#536)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/common/handle/base_parse_qa_handle.py | 31 ++++++++++++ .../handle/impl/qa/csv_parse_qa_handle.py | 45 +++++++++-------- .../handle/impl/qa/xls_parse_qa_handle.py | 49 ++++++++++--------- .../handle/impl/qa/xlsx_parse_qa_handle.py | 48 ++++++++++-------- .../serializers/document_serializers.py | 2 +- 5 files changed, 110 insertions(+), 65 deletions(-) diff --git a/apps/common/handle/base_parse_qa_handle.py b/apps/common/handle/base_parse_qa_handle.py index 948b8e536..79c49dbca 100644 --- a/apps/common/handle/base_parse_qa_handle.py +++ b/apps/common/handle/base_parse_qa_handle.py @@ -9,6 +9,37 @@ from abc import ABC, abstractmethod +def get_row_value(row, title_row_index_dict, field): + index = title_row_index_dict.get(field) + if index is None: + return None + if (len(row) - 1) >= index: + return row[index] + return None + + +def get_title_row_index_dict(title_row_list): + title_row_index_dict = {} + if len(title_row_list) == 1: + title_row_index_dict['content'] = 0 + elif len(title_row_list) == 1: + title_row_index_dict['title'] = 0 + title_row_index_dict['content'] = 1 + else: + title_row_index_dict['title'] = 0 + title_row_index_dict['content'] = 1 + title_row_index_dict['problem_list'] = 2 + for index in range(len(title_row_list)): + title_row = title_row_list[index] + if title_row.startswith('分段标题'): + title_row_index_dict['title'] = index + if title_row.startswith('分段内容'): + title_row_index_dict['content'] = index + if title_row.startswith('问题'): + title_row_index_dict['problem_list'] = index + return title_row_index_dict + + class BaseParseQAHandle(ABC): @abstractmethod def support(self, file, get_buffer): diff --git a/apps/common/handle/impl/qa/csv_parse_qa_handle.py b/apps/common/handle/impl/qa/csv_parse_qa_handle.py index 6952bf8d8..30fe3d78b 100644 --- a/apps/common/handle/impl/qa/csv_parse_qa_handle.py +++ b/apps/common/handle/impl/qa/csv_parse_qa_handle.py @@ -11,7 +11,7 @@ import io from charset_normalizer import detect -from common.handle.base_parse_qa_handle import BaseParseQAHandle +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value def read_csv_standard(file_path): @@ -32,25 +32,28 @@ class CsvParseQAHandle(BaseParseQAHandle): def handle(self, file, get_buffer): buffer = get_buffer(file) - reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) try: - title_row_list = reader.__next__() + reader = csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) + try: + title_row_list = reader.__next__() + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] + if len(title_row_list) == 0: + return [{'name': file.name, 'paragraphs': []}] + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in reader: + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + paragraph_list.append({'title': title[0:255], + 'content': content[0:4096], + 'problem_list': problem_list}) + return [{'name': file.name, 'paragraphs': paragraph_list}] except Exception as e: - return [] - title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2} - for index in range(len(title_row_list)): - title_row = title_row_list[index] - if title_row.startswith('分段标题'): - title_row_index_dict['title'] = index - if title_row.startswith('分段内容'): - title_row_index_dict['content'] = index - if title_row.startswith('问题'): - title_row_index_dict['problem_list'] = index - paragraph_list = [] - for row in reader: - problem = row[title_row_index_dict.get('problem_list')] - problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] - paragraph_list.append({'title': row[title_row_index_dict.get('title')][0:255], - 'content': row[title_row_index_dict.get('content')][0:4096], - 'problem_list': problem_list}) - return [{'name': file.name, 'paragraphs': paragraph_list}] + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/common/handle/impl/qa/xls_parse_qa_handle.py b/apps/common/handle/impl/qa/xls_parse_qa_handle.py index 090745141..88e851ff4 100644 --- a/apps/common/handle/impl/qa/xls_parse_qa_handle.py +++ b/apps/common/handle/impl/qa/xls_parse_qa_handle.py @@ -9,7 +9,7 @@ import xlrd -from common.handle.base_parse_qa_handle import BaseParseQAHandle +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value def handle_sheet(file_name, sheet): @@ -17,22 +17,22 @@ def handle_sheet(file_name, sheet): try: title_row_list = next(rows) except Exception as e: - return None - title_row_index_dict = {'title': 0, 'content': 1, 'problem_list': 2} - for index in range(len(title_row_list)): - title_row = str(title_row_list[index]) - if title_row.startswith('分段标题'): - title_row_index_dict['title'] = index - if title_row.startswith('分段内容'): - title_row_index_dict['content'] = index - if title_row.startswith('问题'): - title_row_index_dict['problem_list'] = index + return {'name': file_name, 'paragraphs': []} + if len(title_row_list) == 0: + return {'name': file_name, 'paragraphs': []} + title_row_index_dict = get_title_row_index_dict(title_row_list) paragraph_list = [] for row in rows: - problem = str(row[title_row_index_dict.get('problem_list')]) + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] - paragraph_list.append({'title': str(row[title_row_index_dict.get('title')])[0:255], - 'content': str(row[title_row_index_dict.get('content')])[0:4096], + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + paragraph_list.append({'title': title[0:255], + 'content': content[0:4096], 'problem_list': problem_list}) return {'name': file_name, 'paragraphs': paragraph_list} @@ -40,16 +40,21 @@ def handle_sheet(file_name, sheet): class XlsParseQAHandle(BaseParseQAHandle): def support(self, file, get_buffer): file_name: str = file.name.lower() - if file_name.endswith(".xls"): + buffer = get_buffer(file) + if file_name.endswith(".xls") and xlrd.inspect_format(content=buffer): return True return False def handle(self, file, get_buffer): buffer = get_buffer(file) - workbook = xlrd.open_workbook(file_contents=buffer) - worksheets = workbook.sheets() - worksheets_size = len(worksheets) - return [row for row in - [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( - sheet.name, sheet) for sheet - in worksheets] if row is not None] + try: + workbook = xlrd.open_workbook(file_contents=buffer) + worksheets = workbook.sheets() + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet) if worksheets_size == 1 and sheet.name == 'Sheet1' else handle_sheet( + sheet.name, sheet) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py index 9ed0ac06a..d8dd97160 100644 --- a/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py +++ b/apps/common/handle/impl/qa/xlsx_parse_qa_handle.py @@ -10,30 +10,32 @@ import io import openpyxl -from common.handle.base_parse_qa_handle import BaseParseQAHandle +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value def handle_sheet(file_name, sheet): rows = sheet.rows try: title_row_list = next(rows) + title_row_list = [row.value for row in title_row_list] except Exception as e: - return None - title_row_index_dict = {} - for index in range(len(title_row_list)): - title_row = str(title_row_list[index].value) - if title_row.startswith('分段标题'): - title_row_index_dict['title'] = index - if title_row.startswith('分段内容'): - title_row_index_dict['content'] = index - if title_row.startswith('问题'): - title_row_index_dict['problem_list'] = index + return {'name': file_name, 'paragraphs': []} + if len(title_row_list) == 0: + return {'name': file_name, 'paragraphs': []} + title_row_index_dict = get_title_row_index_dict(title_row_list) paragraph_list = [] for row in rows: - problem = str(row[title_row_index_dict.get('problem_list')].value) + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem.value) if problem is not None else '' problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] - paragraph_list.append({'title': str(row[title_row_index_dict.get('title')].value)[0:255], - 'content': str(row[title_row_index_dict.get('content')].value)[0:4096], + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title.value) if title is not None else '' + content = content.value + paragraph_list.append({'title': title[0:255], + 'content': content[0:4096], 'problem_list': problem_list}) return {'name': file_name, 'paragraphs': paragraph_list} @@ -47,10 +49,14 @@ class XlsxParseQAHandle(BaseParseQAHandle): def handle(self, file, get_buffer): buffer = get_buffer(file) - workbook = openpyxl.load_workbook(io.BytesIO(buffer)) - worksheets = workbook.worksheets - worksheets_size = len(worksheets) - return [row for row in - [handle_sheet(file.name, sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( - sheet.title, sheet) for sheet - in worksheets] if row is not None] + try: + workbook = openpyxl.load_workbook(io.BytesIO(buffer)) + worksheets = workbook.worksheets + worksheets_size = len(worksheets) + return [row for row in + [handle_sheet(file.name, + sheet) if worksheets_size == 1 and sheet.title == 'Sheet1' else handle_sheet( + sheet.title, sheet) for sheet + in worksheets] if row is not None] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/apps/dataset/serializers/document_serializers.py b/apps/dataset/serializers/document_serializers.py index e97b2279c..b596e7108 100644 --- a/apps/dataset/serializers/document_serializers.py +++ b/apps/dataset/serializers/document_serializers.py @@ -523,8 +523,8 @@ class DocumentSerializers(ApiMixin, serializers.Serializer): @staticmethod def parse_qa_file(file): + get_buffer = FileBufferHandle().get_buffer for parse_qa_handle in parse_qa_handle_list: - get_buffer = FileBufferHandle().get_buffer if parse_qa_handle.support(file, get_buffer): return parse_qa_handle.handle(file, get_buffer) raise AppApiException(500, '不支持的文件格式')