Skip to content
This repository has been archived by the owner on Apr 8, 2024. It is now read-only.

Commit

Permalink
Extend coverage: add unit tests where they miss most (#195)
Browse files Browse the repository at this point in the history
* add unit tests for lightgbm utils
* add coverage config, remove tasks from coverage
* add tests for data2bin
* remove distributed tests for now
  • Loading branch information
jfomhover authored Dec 9, 2021
1 parent bf90d31 commit c7298fa
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 0 deletions.
5 changes: 5 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[report]
exclude_lines =
pragma: no cover
omit =
src/common/tasks.py
48 changes: 48 additions & 0 deletions tests/common/test_lightgbm_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Tests src/common/io.py"""
import os
import pytest
from unittest.mock import call, Mock, patch

from common.lightgbm_utils import LightGBMCallbackHandler
from lightgbm.callback import CallbackEnv

def test_lightgbm_callback_handler():
metrics_logger = Mock()

callback_handler = LightGBMCallbackHandler(
metrics_logger, metrics_prefix=None, metrics_suffix=None
)

# namedtuple
# see https://lightgbm.readthedocs.io/en/latest/_modules/lightgbm/callback.html
callback_env = CallbackEnv(
None, # model
{"foo_param": 0.32}, # params
3, # iteration
0, # begin_iteration
5, # end_iteration
[
# list of tuples
(
"valid_0", # dataset name
"rmse", # evaluation name
12345.0, # result
None, # _
),
(
"valid_0", # dataset name
"l2", # evaluation name
3456.0, # result
None, # _
)
]
)
callback_handler.callback(callback_env)

metrics_logger.log_metric.assert_has_calls(
[
call(key="valid_0.rmse", value=12345.0, step=3),
call(key="valid_0.l2", value=3456.0, step=3)
],
any_order=True
)
37 changes: 37 additions & 0 deletions tests/scripts/test_lightgbm_data2bin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
test src/scripts/partition_data/partition.py
"""
import os
import sys
import tempfile
import pytest
from unittest.mock import patch

from scripts.data_processing.lightgbm_data2bin import data2bin

# IMPORTANT: see conftest.py for fixtures

def test_lightgbm_data2bin(temporary_dir, regression_train_sample, regression_test_sample):
"""Tests src/scripts/data_processing/lightgbm_data2bin/data2bin.py"""
binary_train_data_dir = os.path.join(temporary_dir, "binary_train_data")
binary_test_data_dir = os.path.join(temporary_dir, "binary_test_data")

# create test arguments for the script

script_args = [
"data2bin.py",
"--train", regression_train_sample,
"--test", regression_test_sample,
"--output_train", binary_train_data_dir,
"--output_test", binary_test_data_dir,
"--header", "False",
"--label_column", "0",
"--max_bin", "255",
]

# replaces sys.argv with test arguments and run main
with patch.object(sys, "argv", script_args):
data2bin.main()

assert os.path.isfile(os.path.join(binary_train_data_dir, "train.bin"))
assert os.path.isfile(os.path.join(binary_test_data_dir, "test_0.bin"))

0 comments on commit c7298fa

Please sign in to comment.