diff --git a/apps/application/serializers/application_chat.py b/apps/application/serializers/application_chat.py
index b86f06489..a94f3a82e 100644
--- a/apps/application/serializers/application_chat.py
+++ b/apps/application/serializers/application_chat.py
@@ -24,7 +24,7 @@ from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE
 from rest_framework import serializers
 
 from application.models import Chat, Application, ChatRecord
-from common.db.search import get_dynamics_model, native_search, native_page_search
+from common.db.search import get_dynamics_model, native_search, native_page_search, native_page_handler
 from common.exception.app_exception import AppApiException
 from common.utils.common import get_file_content
 from maxkb.conf import PROJECT_DIR
@@ -95,7 +95,8 @@ class ApplicationChatQuerySerializers(serializers.Serializer):
             'trample_num': models.IntegerField(),
             'comparer': models.CharField(),
             'application_chat.update_time': models.DateTimeField(),
-            'application_chat.id': models.UUIDField(), }))
+            'application_chat.id': models.UUIDField(),
+            'application_chat_record_temp.id': models.UUIDField()}))
 
         base_query_dict = {'application_chat.application_id': self.data.get("application_id"),
                            'application_chat.update_time__gte': start_time,
@@ -106,7 +107,6 @@ class ApplicationChatQuerySerializers(serializers.Serializer):
 
         if 'username' in self.data and self.data.get('username') is not None:
             base_query_dict['application_chat.asker__username__icontains'] = self.data.get('username')
-
         if select_ids is not None and len(select_ids) > 0:
             base_query_dict['application_chat.id__in'] = select_ids
         base_condition = Q(**base_query_dict)
@@ -180,25 +180,26 @@ class ApplicationChatQuerySerializers(serializers.Serializer):
                 str(row.get('create_time').astimezone(pytz.timezone(TIME_ZONE)).strftime('%Y-%m-%d %H:%M:%S')
                     if row.get('create_time') is not None else None)]
 
+    @staticmethod
+    def reset_value(value):
+        if isinstance(value, str):
+            value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
+        if isinstance(value, datetime.datetime):
+            eastern = pytz.timezone(TIME_ZONE)
+            c = datetime.timezone(eastern._utcoffset)
+            value = value.astimezone(c)
+        return value
+
     def export(self, data, with_valid=True):
         if with_valid:
             self.is_valid(raise_exception=True)
             ApplicationChatRecordExportRequest(data=data).is_valid(raise_exception=True)
-        data_list = native_search(self.get_query_set(data.get('select_ids')),
-                                  select_string=get_file_content(
-                                      os.path.join(PROJECT_DIR, "apps", "application", 'sql',
-                                                   ('export_application_chat_ee.sql' if ['PE', 'EE'].__contains__(
-                                                       edition) else 'export_application_chat.sql'))),
-                                  with_table_name=False)
-
-        batch_size = 500
-
         def stream_response():
-            workbook = openpyxl.Workbook()
-            worksheet = workbook.active
-            worksheet.title = 'Sheet1'
-
+            workbook = openpyxl.Workbook(write_only=True)
+            worksheet = workbook.create_sheet(title='Sheet1')
+            current_page = 1
+            page_size = 500
             headers = [gettext('Conversation ID'), gettext('summary'), gettext('User Questions'),
                        gettext('Problem after optimization'), gettext('answer'), gettext('User feedback'),
@@ -207,24 +208,22 @@ class ApplicationChatQuerySerializers(serializers.Serializer):
                        gettext('Annotation'), gettext('USER'), gettext('Consuming tokens'),
                        gettext('Time consumed (s)'), gettext('Question Time')]
 
-            for col_idx, header in enumerate(headers, 1):
-                cell = worksheet.cell(row=1, column=col_idx)
-                cell.value = header
-
-            for i in range(0, len(data_list), batch_size):
-                batch_data = data_list[i:i + batch_size]
-
-                for row_idx, row in enumerate(batch_data, start=i + 2):
-                    for col_idx, value in enumerate(self.to_row(row), 1):
-                        cell = worksheet.cell(row=row_idx, column=col_idx)
-                        if isinstance(value, str):
-                            value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
-                        if isinstance(value, datetime.datetime):
-                            eastern = pytz.timezone(TIME_ZONE)
-                            c = datetime.timezone(eastern._utcoffset)
-                            value = value.astimezone(c)
-                        cell.value = value
+            worksheet.append(headers)
+            for data_list in native_page_handler(page_size, self.get_query_set(data.get('select_ids')),
+                                                 primary_key='application_chat_record_temp.id',
+                                                 primary_queryset='default_queryset',
+                                                 get_primary_value=lambda item: item.get('id'),
+                                                 select_string=get_file_content(
+                                                     os.path.join(PROJECT_DIR, "apps", "application", 'sql',
+                                                                  ('export_application_chat_ee.sql' if ['PE',
+                                                                                                        'EE'].__contains__(
+                                                                      edition) else 'export_application_chat.sql'))),
+                                                 with_table_name=False):
+                for item in data_list:
+                    row = [self.reset_value(v) for v in self.to_row(item)]
+                    worksheet.append(row)
+                current_page = current_page + 1
 
             output = BytesIO()
             workbook.save(output)
             output.seek(0)
diff --git a/apps/application/sql/export_application_chat.sql b/apps/application/sql/export_application_chat.sql
index 2a607cd55..96da10873 100644
--- a/apps/application/sql/export_application_chat.sql
+++ b/apps/application/sql/export_application_chat.sql
@@ -1,4 +1,5 @@
 SELECT
+    application_chat_record_temp.id AS id,
     application_chat."id" as chat_id,
     application_chat.abstract as abstract,
     application_chat_record_temp.problem_text as problem_text,
diff --git a/apps/application/sql/export_application_chat_ee.sql b/apps/application/sql/export_application_chat_ee.sql
index d0faebdf2..07c41ca1f 100644
--- a/apps/application/sql/export_application_chat_ee.sql
+++ b/apps/application/sql/export_application_chat_ee.sql
@@ -1,4 +1,5 @@
 SELECT
+    application_chat_record_temp.id AS id,
     application_chat."id" as chat_id,
     application_chat.abstract as abstract,
     application_chat_record_temp.problem_text as problem_text,
diff --git a/apps/common/db/search.py b/apps/common/db/search.py
index 23d0113d6..2f0110b8f 100644
--- a/apps/common/db/search.py
+++ b/apps/common/db/search.py
@@ -16,9 +16,10 @@ from common.db.compiler import AppSQLCompiler
 from common.db.sql_execute import select_one, select_list, update_execute
 from common.result import Page
 
-
 # 添加模型缓存
 _model_cache = {}
+
+
 def get_dynamics_model(attr: dict, table_name='dynamics'):
     """
     获取一个动态的django模型
@@ -29,24 +30,24 @@ def get_dynamics_model(attr: dict, table_name='dynamics'):
     # 创建缓存键,基于属性和表名
     cache_key = hashlib.md5(f"{table_name}_{str(sorted(attr.items()))}".encode()).hexdigest()
     # print(f'cache_key: {cache_key}')
-
+
     # 如果模型已存在,直接返回缓存的模型
     if cache_key in _model_cache:
         return _model_cache[cache_key]
-
+
     attributes = {
         "__module__": "knowledge.models",
         "Meta": type("Meta", (), {'db_table': table_name}),
         **attr
     }
-
+
     # 使用唯一的类名避免冲突
     class_name = f'Dynamics_{cache_key[:8]}'
     model_class = type(class_name, (models.Model,), attributes)
-
+
     # 缓存模型
     _model_cache[cache_key] = model_class
-
+
     return model_class
 
 
@@ -189,6 +190,51 @@ def native_page_search(current_page: int, page_size: int, queryset: QuerySet | D
     return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size)
 
 
+def native_page_handler(page_size: int,
+                        queryset: QuerySet | Dict[str, QuerySet],
+                        select_string: str,
+                        field_replace_dict=None,
+                        with_table_name=False,
+                        primary_key=None,
+                        get_primary_value=None,
+                        primary_queryset: str = None,
+                        ):
+    if isinstance(queryset, Dict):
+        exec_sql, exec_params = generate_sql_by_query_dict({**queryset,
+                                                            primary_queryset: queryset[primary_queryset].order_by(
+                                                                primary_key)}, select_string, field_replace_dict, with_table_name)
+    else:
+        exec_sql, exec_params = generate_sql_by_query(queryset.order_by(
+            primary_key), select_string, field_replace_dict, with_table_name)
+    total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql
+    total = select_one(total_sql, exec_params)
+    processed_count = 0
+    last_id = None
+    while processed_count < total.get("count"):
+        if last_id is not None:
+            if isinstance(queryset, Dict):
+                exec_sql, exec_params = generate_sql_by_query_dict({**queryset,
+                                                                    primary_queryset: queryset[primary_queryset].filter(
+                                                                        **{f"{primary_key}__gt": last_id}).order_by(
+                                                                        primary_key)},
+                                                                   select_string, field_replace_dict,
+                                                                   with_table_name)
+            else:
+                exec_sql, exec_params = generate_sql_by_query(
+                    queryset.filter(**{f"{primary_key}__gt": last_id}).order_by(
+                        primary_key),
+                    select_string, field_replace_dict,
+                    with_table_name)
+        limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql(
+            0, page_size
+        )
+        page_sql = exec_sql + " " + limit_sql
+        result = select_list(page_sql, exec_params)
+        yield result
+        processed_count += page_size
+        last_id = get_primary_value(result[-1])
+
+
 def get_field_replace_dict(queryset: QuerySet):
     """
     获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx"