diff --git a/data_juicer/core/data.py b/data_juicer/core/data.py
index d0f8083e1..f9af23f00 100644
--- a/data_juicer/core/data.py
+++ b/data_juicer/core/data.py
@@ -21,6 +21,7 @@
cleanup_compressed_cache_files,
compress, decompress)
from data_juicer.utils.fingerprint_utils import generate_fingerprint
+from data_juicer.utils.logger_utils import make_log_summarization
from data_juicer.utils.process_utils import setup_mp
@@ -258,6 +259,9 @@ def process(
if work_dir and enable_insight_mining:
logger.info('Insight mining for each OP...')
adapter.insight_mining()
+ # summarize the error/warning logs
+ if work_dir:
+ make_log_summarization()
return dataset
def update_args(self, args, kargs, is_filter=False):
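
Note: the hook above makes every run that sets a work_dir end with the new log summarization. A minimal sketch of invoking it directly is shown below; it assumes setup_logger() has already been called (so the log file path resolves) and that the default log file name is used -- both assumptions are for illustration only and are not part of this patch.

    # minimal sketch, not part of the patch
    from data_juicer.utils.logger_utils import setup_logger, make_log_summarization

    setup_logger(save_dir='./outputs/log')   # also registers the *_DEBUG/_ERROR/_WARNING sinks added later in this diff
    # ... run some ops; per-sample failures are logged by the wrappers in base_op.py ...
    make_log_summarization(max_show_item=5)  # logs a tabulated error/warning summary
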
diff --git a/data_juicer/ops/base_op.py b/data_juicer/ops/base_op.py
index f230e600b..698203f37 100644
--- a/data_juicer/ops/base_op.py
+++ b/data_juicer/ops/base_op.py
@@ -1,5 +1,4 @@
import copy
-import traceback
from functools import wraps
import numpy as np
@@ -48,11 +47,14 @@ def wrapper(sample, *args, **kwargs):
return wrapper
-def catch_map_batches_exception(method):
+def catch_map_batches_exception(method, op_name=None):
"""
For batched-map sample-level fault tolerance.
"""
+ if op_name is None:
+ op_name = method.__name__
+
@wraps(method)
@convert_arrow_to_python
def wrapper(samples, *args, **kwargs):
@@ -60,10 +62,8 @@ def wrapper(samples, *args, **kwargs):
return method(samples, *args, **kwargs)
except Exception as e:
from loguru import logger
- logger.error(
- f'An error occurred in mapper operation when processing '
- f'samples {samples}, {type(e)}: {e}')
- traceback.print_exc()
+ logger.error(f'An error occurred in {op_name} when processing '
+ f'samples "{samples}" -- {type(e)}: {e}')
ret = {key: [] for key in samples.keys()}
ret[Fields.stats] = []
ret[Fields.source_file] = []
@@ -72,12 +72,15 @@ def wrapper(samples, *args, **kwargs):
return wrapper
-def catch_map_single_exception(method, return_sample=True):
+def catch_map_single_exception(method, return_sample=True, op_name=None):
"""
For single-map sample-level fault tolerance.
The input sample is expected to have batch_size = 1.
"""
+ if op_name is None:
+ op_name = method.__name__
+
def is_batched(sample):
val_iter = iter(sample.values())
first_val = next(val_iter)
@@ -101,10 +104,8 @@ def wrapper(sample, *args, **kwargs):
return [res]
except Exception as e:
from loguru import logger
- logger.error(
- f'An error occurred in mapper operation when processing '
- f'sample {sample}, {type(e)}: {e}')
- traceback.print_exc()
+ logger.error(f'An error occurred in {op_name} when processing '
+ f'sample "{sample}" -- {type(e)}: {e}')
ret = {key: [] for key in sample.keys()}
ret[Fields.stats] = []
ret[Fields.source_file] = []
@@ -277,9 +278,11 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
- self.process = catch_map_batches_exception(self.process_batched)
+ self.process = catch_map_batches_exception(self.process_batched,
+ op_name=self._name)
else:
- self.process = catch_map_single_exception(self.process_single)
+ self.process = catch_map_single_exception(self.process_single,
+ op_name=self._name)
# ensure that the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
@@ -366,13 +369,15 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
self.compute_stats = catch_map_batches_exception(
- self.compute_stats_batched)
- self.process = catch_map_batches_exception(self.process_batched)
+ self.compute_stats_batched, op_name=self._name)
+ self.process = catch_map_batches_exception(self.process_batched,
+ op_name=self._name)
else:
self.compute_stats = catch_map_single_exception(
- self.compute_stats_single)
+ self.compute_stats_single, op_name=self._name)
self.process = catch_map_single_exception(self.process_single,
- return_sample=False)
+ return_sample=False,
+ op_name=self._name)
# ensure that the process method is not allowed to be overridden
def __init_subclass__(cls, **kwargs):
@@ -481,9 +486,11 @@ def __init__(self, *args, **kwargs):
# runtime wrappers
if self.is_batched_op():
- self.compute_hash = catch_map_batches_exception(self.compute_hash)
+ self.compute_hash = catch_map_batches_exception(self.compute_hash,
+ op_name=self._name)
else:
- self.compute_hash = catch_map_single_exception(self.compute_hash)
+ self.compute_hash = catch_map_single_exception(self.compute_hash,
+ op_name=self._name)
def compute_hash(self, sample):
"""
@@ -619,7 +626,8 @@ def __init__(self, *args, **kwargs):
queries and responses
"""
super(Aggregator, self).__init__(*args, **kwargs)
- self.process = catch_map_single_exception(self.process_single)
+ self.process = catch_map_single_exception(self.process_single,
+ op_name=self._name)
def process_single(self, sample):
"""
diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py
index 3d97a4424..5df71524e 100644
--- a/data_juicer/ops/filter/language_id_score_filter.py
+++ b/data_juicer/ops/filter/language_id_score_filter.py
@@ -1,7 +1,5 @@
from typing import List, Union
-from loguru import logger
-
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.model_utils import get_model, prepare_model
@@ -55,7 +53,6 @@ def compute_stats_single(self, sample):
ft_model = get_model(self.model_key)
if ft_model is None:
err_msg = 'Model not loaded. Please retry later.'
- logger.error(err_msg)
raise ValueError(err_msg)
pred = ft_model.predict(text)
lang_id = pred[0][0].replace('__label__', '')
diff --git a/data_juicer/ops/op_fusion.py b/data_juicer/ops/op_fusion.py
index 489f90ab0..71e550cb4 100644
--- a/data_juicer/ops/op_fusion.py
+++ b/data_juicer/ops/op_fusion.py
@@ -156,8 +156,8 @@ def __init__(self, name: str, fused_filters: List):
:param fused_filters: a list of filters to be fused.
"""
- super().__init__()
self._name = name
+ super().__init__()
self.fused_filters = fused_filters
# set accelerator to 'cuda' if there exist any ops whose accelerator
# is 'cuda'
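
Note: this reorder is load-bearing: the base constructor reached through super().__init__() now builds the exception wrappers with op_name=self._name, so _name must be assigned before the call. A minimal sketch of the ordering dependency, using a hypothetical subclass:

    # minimal sketch, not part of the patch
    from data_juicer.ops.base_op import Filter

    class FusedExample(Filter):        # hypothetical subclass, for illustration only
        def __init__(self, name, fused_filters):
            self._name = name          # must come first ...
            super().__init__()         # ... because this wraps methods with op_name=self._name
            self.fused_filters = fused_filters
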
diff --git a/data_juicer/utils/logger_utils.py b/data_juicer/utils/logger_utils.py
index a91f610fe..11cbf85b8 100644
--- a/data_juicer/utils/logger_utils.py
+++ b/data_juicer/utils/logger_utils.py
@@ -22,6 +22,9 @@
from loguru import logger
from loguru._file_sink import FileSink
+from tabulate import tabulate
+
+from data_juicer.utils.file_utils import add_suffix_to_filename
LOGGER_SETUP = False
@@ -142,12 +145,88 @@ def setup_logger(save_dir,
)
logger.add(save_file)
+ # add dedicated sinks for the levels of interest: DEBUG, ERROR, WARNING
+ logger.add(
+ add_suffix_to_filename(save_file, '_DEBUG'),
+ level='DEBUG',
+ filter=lambda x: 'DEBUG' == x['level'].name,
+ format=loguru_format,
+ enqueue=True,
+ serialize=True,
+ )
+ logger.add(
+ add_suffix_to_filename(save_file, '_ERROR'),
+ level='ERROR',
+ filter=lambda x: 'ERROR' == x['level'].name,
+ format=loguru_format,
+ enqueue=True,
+ serialize=True,
+ )
+ logger.add(
+ add_suffix_to_filename(save_file, '_WARNING'),
+ level='WARNING',
+ filter=lambda x: 'WARNING' == x['level'].name,
+ format=loguru_format,
+ enqueue=True,
+ serialize=True,
+ )
+
# redirect stdout/stderr to loguru
if redirect:
redirect_sys_output(level)
LOGGER_SETUP = True
+def make_log_summarization(max_show_item=10):
+ error_pattern = r'^An error occurred in (.*?) when ' \
+ r'processing samples? \"(.*?)\" -- (.*?): (.*?)$'
+ log_file = get_log_file_path()
+ error_log_file = add_suffix_to_filename(log_file, '_ERROR')
+ warning_log_file = add_suffix_to_filename(log_file, '_WARNING')
+
+ import jsonlines as jl
+ import regex as re
+
+ # make error summarization
+ error_dict = {}
+ with jl.open(error_log_file) as reader:
+ for error_log in reader:
+ error_msg = error_log['record']['message']
+ find_res = re.findall(error_pattern, error_msg)
+ if len(find_res) > 0:
+ op_name, sample, error_type, error_msg = find_res[0]
+ error = (op_name, error_type, error_msg)
+ error_dict.setdefault(error, 0)
+ error_dict[error] += 1
+ total_error_count = sum(error_dict.values())
+ # make warning summarization
+ warning_count = 0
+ with jl.open(warning_log_file) as reader:
+ for _ in reader:
+ warning_count += 1
+ # make summary log
+ summary = f'Processing finished with:\n' \
+ f'Warnings: {warning_count}\n' \
+ f'Errors: {total_error_count}\n'
+ error_items = list(error_dict.items())
+ error_items.sort(key=lambda it: it[1], reverse=True)
+ error_items = error_items[:max_show_item]
+ # convert error items to a table
+ if len(error_items) > 0:
+ error_table = []
+ table_header = [
+ 'OP/Method', 'Error Type', 'Error Message', 'Error Count'
+ ]
+ for key, num in error_items:
+ op_name, error_type, error_msg = key
+ error_table.append([op_name, error_type, error_msg, num])
+ table = tabulate(error_table, table_header, tablefmt='fancy_grid')
+ summary += table
+ summary += f'\nError/Warning details can be found in the log file ' \
+ f'[{log_file}] and its related log files.'
+ logger.opt(ansi=True).info(summary)
+
+
class HiddenPrints:
"""Define a range that hide the outputs within this range."""