Skip to content

Commit

Permalink
combine the two tests
Browse files Browse the repository at this point in the history
add assertion for passed arguments
rename mock object
  • Loading branch information
root committed Dec 20, 2024
1 parent 6352b90 commit c9a6794
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions tests/program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,31 @@
import pandas as pd

from bimorph_mirror_analysis.__main__ import calculate_optimal_voltages


def test_calculate_optimal_voltages(raw_data: pd.DataFrame):
with patch("bimorph_mirror_analysis.read_file.pd.read_csv") as mock_read_csv:
mock_read_csv.return_value = raw_data
voltages = calculate_optimal_voltages("input_file")
voltages = np.round(voltages, 2)
np.testing.assert_almost_equal(voltages, np.array([72.14, 50.98, 18.59]))
from bimorph_mirror_analysis.maths import find_voltages


def test_calculate_optimal_voltages_mocked(raw_data_pivoted: pd.DataFrame):
with (
patch(
"bimorph_mirror_analysis.__main__.read_bluesky_plan_output"
) as mock_read_bluesky_plan_output,
patch(
"bimorph_mirror_analysis.__main__.find_voltages"
) as mock_calculate_voltages,
patch("bimorph_mirror_analysis.__main__.find_voltages") as mock_find_voltages,
):
# set the mock return values
mock_read_bluesky_plan_output.return_value = (
raw_data_pivoted,
np.array([0.0, 0.0, 0.0]),
100,
)
mock_calculate_voltages.return_value = np.array([72.14, 50.98, 18.59])
calculate_optimal_voltages("input_file")
mock_find_voltages.side_effect = find_voltages
voltages = calculate_optimal_voltages("input_file")
voltages = np.round(voltages, 2)
# assert correct voltages calculated
np.testing.assert_almost_equal(voltages, np.array([72.14, 50.98, 18.59]))

# assert mock was called
mock_read_bluesky_plan_output.assert_called()
mock_calculate_voltages.assert_called()
mock_read_bluesky_plan_output.assert_called_with("input_file")
mock_find_voltages.assert_called()
expected_data = raw_data_pivoted[raw_data_pivoted.columns[1:]].to_numpy() # type: ignore
np.testing.assert_array_equal(mock_find_voltages.call_args[0][0], expected_data) # type: ignore

0 comments on commit c9a6794

Please sign in to comment.