log summarization (#534)
* Separate log files for DEBUG/ERROR/WARNING levels.

* Make log summarizations after OPs are done.

* Make log info aware of the specific op_name.

* Apply pre-commit fixes.

* Resolve the unpacking error caused by the extra capture group added for op_name.
* Reorganize the error list into a table.

* Set the name before initializing the superclass for FusedFilter.

---------

Co-authored-by: 道辕 <[email protected]>
HYLcool and yxdyc authored Jan 14, 2025
1 parent 50f480b commit 0624d44
Showing 5 changed files with 112 additions and 24 deletions.
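For orientation before the per-file diffs: the end-of-run summary renders the collected errors as a table. Below is a minimal sketch of that output, reusing the table headers from make_log_summarization further down; the row contents are made up for illustration.

```python
# Sketch of the end-of-run error table this commit introduces. Headers match
# make_log_summarization below; the sample row is hypothetical.
from tabulate import tabulate

table_header = ['OP/Method', 'Error Type', 'Error Message', 'Error Count']
error_table = [
    ['language_id_score_filter', "<class 'ValueError'>",
     'Model not loaded. Please retry later.', 3],
]
print(tabulate(error_table, table_header, tablefmt='fancy_grid'))
```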
4 changes: 4 additions & 0 deletions data_juicer/core/data.py
```diff
@@ -21,6 +21,7 @@
                                              cleanup_compressed_cache_files,
                                              compress, decompress)
 from data_juicer.utils.fingerprint_utils import generate_fingerprint
+from data_juicer.utils.logger_utils import make_log_summarization
 from data_juicer.utils.process_utils import setup_mp
@@ -258,6 +259,9 @@ def process(
         if work_dir and enable_insight_mining:
             logger.info('Insight mining for each OP...')
             adapter.insight_mining()
+        # make summarization on the error/warning logs
+        if work_dir:
+            make_log_summarization()
         return dataset
 
     def update_args(self, args, kargs, is_filter=False):
```
48 changes: 28 additions & 20 deletions data_juicer/ops/base_op.py
```diff
@@ -1,5 +1,4 @@
 import copy
-import traceback
 from functools import wraps
 
 import numpy as np
@@ -48,22 +47,23 @@ def wrapper(sample, *args, **kwargs):
     return wrapper
 
 
-def catch_map_batches_exception(method):
+def catch_map_batches_exception(method, op_name=None):
     """
     For batched-map sample-level fault tolerance.
     """
 
+    if op_name is None:
+        op_name = method.__name__
+
     @wraps(method)
     @convert_arrow_to_python
     def wrapper(samples, *args, **kwargs):
         try:
             return method(samples, *args, **kwargs)
         except Exception as e:
             from loguru import logger
-            logger.error(
-                f'An error occurred in mapper operation when processing '
-                f'samples {samples}, {type(e)}: {e}')
-            traceback.print_exc()
+            logger.error(f'An error occurred in {op_name} when processing '
+                         f'samples "{samples}" -- {type(e)}: {e}')
             ret = {key: [] for key in samples.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
@@ -72,12 +72,15 @@ def wrapper(samples, *args, **kwargs):
     return wrapper
 
 
-def catch_map_single_exception(method, return_sample=True):
+def catch_map_single_exception(method, return_sample=True, op_name=None):
     """
     For single-map sample-level fault tolerance.
     The input sample is expected batch_size = 1.
     """
 
+    if op_name is None:
+        op_name = method.__name__
+
     def is_batched(sample):
         val_iter = iter(sample.values())
         first_val = next(val_iter)
@@ -101,10 +104,8 @@ def wrapper(sample, *args, **kwargs):
                 return [res]
         except Exception as e:
             from loguru import logger
-            logger.error(
-                f'An error occurred in mapper operation when processing '
-                f'sample {sample}, {type(e)}: {e}')
-            traceback.print_exc()
+            logger.error(f'An error occurred in {op_name} when processing '
+                         f'sample "{sample}" -- {type(e)}: {e}')
             ret = {key: [] for key in sample.keys()}
             ret[Fields.stats] = []
             ret[Fields.source_file] = []
@@ -277,9 +278,11 @@ def __init__(self, *args, **kwargs):
 
         # runtime wrappers
         if self.is_batched_op():
-            self.process = catch_map_batches_exception(self.process_batched)
+            self.process = catch_map_batches_exception(self.process_batched,
+                                                       op_name=self._name)
         else:
-            self.process = catch_map_single_exception(self.process_single)
+            self.process = catch_map_single_exception(self.process_single,
+                                                      op_name=self._name)
 
     # set the process method is not allowed to be overridden
     def __init_subclass__(cls, **kwargs):
@@ -366,13 +369,15 @@ def __init__(self, *args, **kwargs):
         # runtime wrappers
         if self.is_batched_op():
             self.compute_stats = catch_map_batches_exception(
-                self.compute_stats_batched)
-            self.process = catch_map_batches_exception(self.process_batched)
+                self.compute_stats_batched, op_name=self._name)
+            self.process = catch_map_batches_exception(self.process_batched,
+                                                       op_name=self._name)
         else:
             self.compute_stats = catch_map_single_exception(
-                self.compute_stats_single)
+                self.compute_stats_single, op_name=self._name)
             self.process = catch_map_single_exception(self.process_single,
-                                                      return_sample=False)
+                                                      return_sample=False,
+                                                      op_name=self._name)
 
     # set the process method is not allowed to be overridden
     def __init_subclass__(cls, **kwargs):
@@ -481,9 +486,11 @@ def __init__(self, *args, **kwargs):
 
         # runtime wrappers
         if self.is_batched_op():
-            self.compute_hash = catch_map_batches_exception(self.compute_hash)
+            self.compute_hash = catch_map_batches_exception(self.compute_hash,
+                                                            op_name=self._name)
         else:
-            self.compute_hash = catch_map_single_exception(self.compute_hash)
+            self.compute_hash = catch_map_single_exception(self.compute_hash,
+                                                           op_name=self._name)
 
     def compute_hash(self, sample):
         """
@@ -619,7 +626,8 @@ def __init__(self, *args, **kwargs):
             queries and responses
         """
         super(Aggregator, self).__init__(*args, **kwargs)
-        self.process = catch_map_single_exception(self.process_single)
+        self.process = catch_map_single_exception(self.process_single,
+                                                  op_name=self._name)
 
     def process_single(self, sample):
         """
```
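The wrapper change above is easiest to see in isolation. Below is a self-contained sketch of the same pattern, with a toy op and sample that are not part of the diff; the point is that the op name is baked into the error message in exactly the shape that make_log_summarization's regex later groups on.

```python
from functools import wraps

from loguru import logger


def catch_single_exception(method, op_name=None):
    """Toy analogue of catch_map_single_exception: log and skip bad samples."""
    if op_name is None:
        op_name = method.__name__

    @wraps(method)
    def wrapper(sample, *args, **kwargs):
        try:
            return method(sample, *args, **kwargs)
        except Exception as e:
            # same message shape that the summarizer's error_pattern matches
            logger.error(f'An error occurred in {op_name} when processing '
                         f'sample "{sample}" -- {type(e)}: {e}')
            return None  # the real wrapper returns an emptied sample instead

    return wrapper


def broken_op(sample):
    return sample['text'].lower()  # raises KeyError when 'text' is missing


process = catch_single_exception(broken_op, op_name='toy_lowercase_mapper')
process({'content': 'no text field here'})  # error is logged, not raised
```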
3 changes: 0 additions & 3 deletions data_juicer/ops/filter/language_id_score_filter.py
```diff
@@ -1,7 +1,5 @@
 from typing import List, Union
 
-from loguru import logger
-
 from data_juicer.utils.constant import Fields, StatsKeys
 from data_juicer.utils.lazy_loader import LazyLoader
 from data_juicer.utils.model_utils import get_model, prepare_model
@@ -55,7 +53,6 @@ def compute_stats_single(self, sample):
         ft_model = get_model(self.model_key)
         if ft_model is None:
             err_msg = 'Model not loaded. Please retry later.'
-            logger.error(err_msg)
             raise ValueError(err_msg)
         pred = ft_model.predict(text)
         lang_id = pred[0][0].replace('__label__', '')
```
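The explicit logger.error before the raise is dropped because the raised ValueError is now caught and logged once, tagged with the op name, by the exception wrappers in base_op.py above; keeping both calls would report the same failure twice.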
2 changes: 1 addition & 1 deletion data_juicer/ops/op_fusion.py
```diff
@@ -156,8 +156,8 @@ def __init__(self, name: str, fused_filters: List):
         :param fused_filters: a list of filters to be fused.
         """
-        super().__init__()
         self._name = name
+        super().__init__()
         self.fused_filters = fused_filters
         # set accelerator to 'cuda' if there exists any ops whose accelerator
         # is 'cuda'
```
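The reordering matters because Filter.__init__ (see base_op.py above) now reads self._name while installing the op_name-aware wrappers, so the name must exist before the superclass initializer runs. A toy illustration of the pitfall, with hypothetical class names:

```python
class Base:
    def __init__(self):
        # stands in for Filter.__init__, which reads self._name while
        # installing the exception wrappers
        self.wrapper_tag = getattr(self, '_name', '<unset>')


class BadFused(Base):
    def __init__(self, name):
        super().__init__()  # runs before _name exists
        self._name = name


class GoodFused(Base):
    def __init__(self, name):
        self._name = name  # set first, as in the FusedFilter fix
        super().__init__()


print(BadFused('fused_filter').wrapper_tag)   # <unset>
print(GoodFused('fused_filter').wrapper_tag)  # fused_filter
```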
79 changes: 79 additions & 0 deletions data_juicer/utils/logger_utils.py
```diff
@@ -22,6 +22,9 @@
 
 from loguru import logger
 from loguru._file_sink import FileSink
+from tabulate import tabulate
+
+from data_juicer.utils.file_utils import add_suffix_to_filename
 
 LOGGER_SETUP = False
 
@@ -142,12 +145,88 @@ def setup_logger(save_dir,
     )
     logger.add(save_file)
 
+    # for interest of levels: debug, error, warning
+    logger.add(
+        add_suffix_to_filename(save_file, '_DEBUG'),
+        level='DEBUG',
+        filter=lambda x: 'DEBUG' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+    logger.add(
+        add_suffix_to_filename(save_file, '_ERROR'),
+        level='ERROR',
+        filter=lambda x: 'ERROR' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+    logger.add(
+        add_suffix_to_filename(save_file, '_WARNING'),
+        level='WARNING',
+        filter=lambda x: 'WARNING' == x['level'].name,
+        format=loguru_format,
+        enqueue=True,
+        serialize=True,
+    )
+
     # redirect stdout/stderr to loguru
     if redirect:
         redirect_sys_output(level)
     LOGGER_SETUP = True
 
 
+def make_log_summarization(max_show_item=10):
+    error_pattern = r'^An error occurred in (.*?) when ' \
+                    r'processing samples? \"(.*?)\" -- (.*?): (.*?)$'
+    log_file = get_log_file_path()
+    error_log_file = add_suffix_to_filename(log_file, '_ERROR')
+    warning_log_file = add_suffix_to_filename(log_file, '_WARNING')
+
+    import jsonlines as jl
+    import regex as re
+
+    # make error summarization
+    error_dict = {}
+    with jl.open(error_log_file) as reader:
+        for error_log in reader:
+            error_msg = error_log['record']['message']
+            find_res = re.findall(error_pattern, error_msg)
+            if len(find_res) > 0:
+                op_name, sample, error_type, error_msg = find_res[0]
+                error = (op_name, error_type, error_msg)
+                error_dict.setdefault(error, 0)
+                error_dict[error] += 1
+    total_error_count = sum(error_dict.values())
+    # make warning summarization
+    warning_count = 0
+    with jl.open(warning_log_file) as reader:
+        for _ in reader:
+            warning_count += 1
+    # make summary log
+    summary = f'Processing finished with:\n' \
+              f'<yellow>Warnings</yellow>: {warning_count}\n' \
+              f'<red>Errors</red>: {total_error_count}\n'
+    error_items = list(error_dict.items())
+    error_items.sort(key=lambda it: it[1], reverse=True)
+    error_items = error_items[:max_show_item]
+    # convert error items to a table
+    if len(error_items) > 0:
+        error_table = []
+        table_header = [
+            'OP/Method', 'Error Type', 'Error Message', 'Error Count'
+        ]
+        for key, num in error_items:
+            op_name, error_type, error_msg = key
+            error_table.append([op_name, error_type, error_msg, num])
+        table = tabulate(error_table, table_header, tablefmt='fancy_grid')
+        summary += table
+    summary += f'\nError/Warning details can be found in the log file ' \
+               f'[{log_file}] and its related log files.'
+    logger.opt(ansi=True).info(summary)
+
+
 class HiddenPrints:
     """Define a range that hide the outputs within this range."""
```
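To see the new pieces working together, here is a minimal end-to-end sketch: one serialized, ERROR-only sink as in setup_logger, one error in the shape the wrappers emit, then the same parse-and-count pass make_log_summarization performs. The file name is a stand-in for add_suffix_to_filename(save_file, '_ERROR'), and the whole script is illustrative rather than library code.

```python
import jsonlines as jl
import regex as re
from loguru import logger
from tabulate import tabulate

# One serialized, ERROR-only sink, mirroring the sinks added in setup_logger.
error_log_file = 'demo_ERROR.log'
sink_id = logger.add(error_log_file,
                     level='ERROR',
                     filter=lambda x: x['level'].name == 'ERROR',
                     serialize=True)  # one JSON object per line

# An error message in the exact shape the wrappers in base_op.py emit.
logger.error('An error occurred in toy_op when processing '
             "sample \"no-text-sample\" -- <class 'KeyError'>: 'text'")
logger.remove(sink_id)  # flush and close the sink before reading it back

# The same parse-and-count pass that make_log_summarization performs.
error_pattern = (r'^An error occurred in (.*?) when '
                 r'processing samples? \"(.*?)\" -- (.*?): (.*?)$')
error_dict = {}
with jl.open(error_log_file) as reader:
    for error_log in reader:
        msg = error_log['record']['message']
        for op_name, _sample, err_type, err_msg in re.findall(
                error_pattern, msg):
            key = (op_name, err_type, err_msg)
            error_dict[key] = error_dict.get(key, 0) + 1

rows = [[*key, num] for key, num in error_dict.items()]
print(tabulate(rows,
               ['OP/Method', 'Error Type', 'Error Message', 'Error Count'],
               tablefmt='fancy_grid'))
```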
