From 0624d44d1f6b210f734b2c4f2755ffc03e0af77d Mon Sep 17 00:00:00 2001
From: Yilun Huang
Date: Tue, 14 Jan 2025 19:15:48 +0800
Subject: [PATCH] log summarization (#534)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* * separate logs from debug/error/warning

* * make log summarizations after OPs are done.

* make log info aware of specific op_name

* pre commit

* * resolve the unpacking error due to an extra group for op_name is added.

* reorganize the error list to a table

* - set name before initialize super class for FusedFilter

---------

Co-authored-by: 道辕
---
 data_juicer/core/data.py                          |  4 +
 data_juicer/ops/base_op.py                        | 48 ++++++-----
 .../ops/filter/language_id_score_filter.py        |  3 -
 data_juicer/ops/op_fusion.py                      |  2 +-
 data_juicer/utils/logger_utils.py                 | 79 +++++++++++++++++++
 5 files changed, 112 insertions(+), 24 deletions(-)

diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index d0f8083e1..f9af23f00 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -21,6 +21,7 @@
                                          cleanup_compressed_cache_files,
                                          compress, decompress)
 from data_juicer.utils.fingerprint_utils import generate_fingerprint
+from data_juicer.utils.logger_utils import make_log_summarization
 from data_juicer.utils.process_utils import setup_mp
 
@@ -258,6 +259,9 @@ def process(
             if work_dir and enable_insight_mining:
                 logger.info('Insight mining for each OP...')
                 adapter.insight_mining()
+            # make summarization on the error/warning logs
+            if work_dir:
+                make_log_summarization()
         return dataset
 
     def update_args(self, args, kargs, is_filter=False):
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index f230e600b..698203f37 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -1,5 +1,4 @@
 import copy
-import traceback
 from functools import wraps
 
 import numpy as np
@@ -48,11 +47,14 @@ def wrapper(sample, *args, **kwargs):
     return wrapper
 
 
-def catch_map_batches_exception(method):
+def catch_map_batches_exception(method, op_name=None):
     """
     For batched-map sample-level fault tolerance.
     """
 
+    if op_name is None:
+        op_name = method.__name__
+
     @wraps(method)
     @convert_arrow_to_python
     def wrapper(samples, *args, **kwargs):
@@ -60,10 +62,8 @@ def wrapper(samples, *args, **kwargs):
             return method(samples, *args, **kwargs)
         except Exception as e:
             from loguru import logger
-            logger.error(
-                f'An error occurred in mapper operation when processing '
-                f'samples {samples}, {type(e)}: {e}')
-            traceback.print_exc()
+            logger.error(f'An error occurred in {op_name} when processing '
+                         f'samples "{samples}" -- {type(e)}: {e}')
             ret = {key: [] for key in samples.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
@@ -72,12 +72,15 @@ def wrapper(samples, *args, **kwargs):
     return wrapper
 
 
-def catch_map_single_exception(method, return_sample=True):
+def catch_map_single_exception(method, return_sample=True, op_name=None):
     """
     For single-map sample-level fault tolerance.
     The input sample is expected batch_size = 1.
     """
 
+    if op_name is None:
+        op_name = method.__name__
+
     def is_batched(sample):
         val_iter = iter(sample.values())
         first_val = next(val_iter)
@@ -101,10 +104,8 @@ def wrapper(sample, *args, **kwargs):
                 return [res]
         except Exception as e:
             from loguru import logger
-            logger.error(
-                f'An error occurred in mapper operation when processing '
-                f'sample {sample}, {type(e)}: {e}')
-            traceback.print_exc()
+            logger.error(f'An error occurred in {op_name} when processing '
+                         f'sample "{sample}" -- {type(e)}: {e}')
             ret = {key: [] for key in sample.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
@@ -277,9 +278,11 @@ def __init__(self, *args, **kwargs):
 
         # runtime wrappers
         if self.is_batched_op():
-            self.process = catch_map_batches_exception(self.process_batched)
+            self.process = catch_map_batches_exception(self.process_batched,
+                                                       op_name=self._name)
         else:
-            self.process = catch_map_single_exception(self.process_single)
+            self.process = catch_map_single_exception(self.process_single,
+                                                      op_name=self._name)
 
     # set the process method is not allowed to be overridden
     def __init_subclass__(cls, **kwargs):
@@ -366,13 +369,15 @@ def __init__(self, *args, **kwargs):
         # runtime wrappers
         if self.is_batched_op():
             self.compute_stats = catch_map_batches_exception(
-                self.compute_stats_batched)
-            self.process = catch_map_batches_exception(self.process_batched)
+                self.compute_stats_batched, op_name=self._name)
+            self.process = catch_map_batches_exception(self.process_batched,
+                                                       op_name=self._name)
         else:
             self.compute_stats = catch_map_single_exception(
-                self.compute_stats_single)
+                self.compute_stats_single, op_name=self._name)
             self.process = catch_map_single_exception(self.process_single,
-                                                      return_sample=False)
+                                                      return_sample=False,
+                                                      op_name=self._name)
 
     # set the process method is not allowed to be overridden
     def __init_subclass__(cls, **kwargs):
@@ -481,9 +486,11 @@ def __init__(self, *args, **kwargs):
 
         # runtime wrappers
         if self.is_batched_op():
-            self.compute_hash = catch_map_batches_exception(self.compute_hash)
+            self.compute_hash = catch_map_batches_exception(self.compute_hash,
+                                                            op_name=self._name)
         else:
-            self.compute_hash = catch_map_single_exception(self.compute_hash)
+            self.compute_hash = catch_map_single_exception(self.compute_hash,
+                                                           op_name=self._name)
 
     def compute_hash(self, sample):
         """
@@ -619,7 +626,8 @@ def __init__(self, *args, **kwargs):
             queries and responses
         """
         super(Aggregator, self).__init__(*args, **kwargs)
-        self.process = catch_map_single_exception(self.process_single)
+        self.process = catch_map_single_exception(self.process_single,
+                                                  op_name=self._name)
 
     def process_single(self, sample):
         """
diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py
index 3d97a4424..5df71524e 100644
--- a/data_juicer/ops/filter/language_id_score_filter.py
+++ b/data_juicer/ops/filter/language_id_score_filter.py
@@ -1,7 +1,5 @@
 from typing import List, Union
 
-from loguru import logger
-
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.lazy_loader import LazyLoader
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -55,7 +53,6 @@ def compute_stats_single(self, sample):
         ft_model = get_model(self.model_key)
         if ft_model is None:
             err_msg = 'Model not loaded. Please retry later.'
-            logger.error(err_msg)
             raise ValueError(err_msg)
         pred = ft_model.predict(text)
         lang_id = pred[0][0].replace('__label__', '')
diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py
index 489f90ab0..71e550cb4 100644
--- a/data_juicer/ops/op_fusion.py
+++ b/data_juicer/ops/op_fusion.py
@@ -156,8 +156,8 @@ def __init__(self, name: str, fused_filters: List):
 
         :param fused_filters: a list of filters to be fused.
         """
-        super().__init__()
         self._name = name
+        super().__init__()
         self.fused_filters = fused_filters
         # set accelerator to 'cuda' if there exists any ops whose accelerator
         # is 'cuda'
diff --git a/data_juicer/utils/logger_utils.py b/data_juicer/utils/logger_utils.py
index a91f610fe..11cbf85b8 100644
--- a/data_juicer/utils/logger_utils.py
+++ b/data_juicer/utils/logger_utils.py
@@ -22,6 +22,9 @@
 
 from loguru import logger
 from loguru._file_sink import FileSink
+from tabulate import tabulate
+
+from data_juicer.utils.file_utils import add_suffix_to_filename
 
 LOGGER_SETUP = False
 
@@ -142,12 +145,88 @@ def setup_logger(save_dir,
     )
     logger.add(save_file)
 
+    # for interest of levels: debug, error, warning
+    logger.add(
+        add_suffix_to_filename(save_file, '_DEBUG'),
+        level='DEBUG',
+        filter=lambda x: 'DEBUG' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+    logger.add(
+        add_suffix_to_filename(save_file, '_ERROR'),
+        level='ERROR',
+        filter=lambda x: 'ERROR' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+    logger.add(
+        add_suffix_to_filename(save_file, '_WARNING'),
+        level='WARNING',
+        filter=lambda x: 'WARNING' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+
     # redirect stdout/stderr to loguru
     if redirect:
         redirect_sys_output(level)
     LOGGER_SETUP = True
 
 
+def make_log_summarization(max_show_item=10):
+    error_pattern = r'^An error occurred in (.*?) when ' \
+                    r'processing samples? \"(.*?)\" -- (.*?): (.*?)$'
+    log_file = get_log_file_path()
+    error_log_file = add_suffix_to_filename(log_file, '_ERROR')
+    warning_log_file = add_suffix_to_filename(log_file, '_WARNING')
+
+    import jsonlines as jl
+    import regex as re
+
+    # make error summarization
+    error_dict = {}
+    with jl.open(error_log_file) as reader:
+        for error_log in reader:
+            error_msg = error_log['record']['message']
+            find_res = re.findall(error_pattern, error_msg)
+            if len(find_res) > 0:
+                op_name, sample, error_type, error_msg = find_res[0]
+                error = (op_name, error_type, error_msg)
+                error_dict.setdefault(error, 0)
+                error_dict[error] += 1
+    total_error_count = sum(error_dict.values())
+    # make warning summarization
+    warning_count = 0
+    with jl.open(warning_log_file) as reader:
+        for _ in reader:
+            warning_count += 1
+    # make summary log
+    summary = f'Processing finished with:\n' \
+              f'Warnings: {warning_count}\n' \
+              f'Errors: {total_error_count}\n'
+    error_items = list(error_dict.items())
+    error_items.sort(key=lambda it: it[1], reverse=True)
+    error_items = error_items[:max_show_item]
+    # convert error items to a table
+    if len(error_items) > 0:
+        error_table = []
+        table_header = [
+            'OP/Method', 'Error Type', 'Error Message', 'Error Count'
+        ]
+        for key, num in error_items:
+            op_name, error_type, error_msg = key
+            error_table.append([op_name, error_type, error_msg, num])
+        table = tabulate(error_table, table_header, tablefmt='fancy_grid')
+        summary += table
+    summary += f'\nError/Warning details can be found in the log file ' \
+               f'[{log_file}] and its related log files.'
+    logger.opt(ansi=True).info(summary)
+
+
 class HiddenPrints:
     """Define a range that hide the outputs within this range."""
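
Usage sketch (not part of the patch): the wrappers above log one ERROR line per failed
sample in the form `An error occurred in <op_name> when processing sample(s) "<sample>"
-- <type>: <msg>`, the new `_ERROR` sink stores those records as serialized JSONL, and
make_log_summarization groups them by (op_name, error_type, error_msg). The snippet below
replays that grouping step on made-up records; the helper name summarize_error_records
and the sample messages are illustrative only, and the stdlib `re` module stands in for
the `regex` package used in the patch.

import json
import re
from collections import Counter

# Same grouping pattern used by make_log_summarization in the patch;
# the stdlib `re` module handles it the same way `regex` does here.
ERROR_PATTERN = (r'^An error occurred in (.*?) when '
                 r'processing samples? \"(.*?)\" -- (.*?): (.*?)$')


def summarize_error_records(jsonl_lines, max_show_item=10):
    """Group serialized loguru ERROR records (one JSON object per line,
    message stored under record["message"]) by (op, error type, message)."""
    counter = Counter()
    for line in jsonl_lines:
        message = json.loads(line)['record']['message']
        found = re.findall(ERROR_PATTERN, message)
        if found:
            op_name, _sample, err_type, err_msg = found[0]
            counter[(op_name, err_type, err_msg)] += 1
    return counter.most_common(max_show_item)


if __name__ == '__main__':
    # Two made-up records shaped like loguru output with serialize=True.
    fake_lines = [
        json.dumps({'record': {'message':
            'An error occurred in clean_links_mapper when processing '
            'sample "{...}" -- <class \'ValueError\'>: bad text'}}),
    ] * 2
    for (op, err_type, err_msg), num in summarize_error_records(fake_lines):
        print(f'{op} | {err_type} | {err_msg} | count={num}')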