diff --git a/kerngen/high_parser/parser.py b/kerngen/high_parser/parser.py index f64d5a1..32ad258 100644 --- a/kerngen/high_parser/parser.py +++ b/kerngen/high_parser/parser.py @@ -34,12 +34,22 @@ def __init__(self, iterable, symbols_map): self._commands = list(iterable) self._symbols_map = symbols_map + @staticmethod + def _get_context_from_commands_list(commands): + """Validates that the commands list contains a single context""" + context_list = [context for context in commands if isinstance(context, Context)] + if not context_list: + raise LookupError("No Context found for commands list for ParseResults") + if len(context_list) > 1: + raise LookupError( + "Multiple Context found in commands list for ParseResults" + ) + return context_list[0] + @property def context(self): """Return found context""" - return next( - context for context in self._commands if isinstance(context, Context) - ) + return ParseResults._get_context_from_commands_list(self._commands) @property def commands(self): diff --git a/kerngen/tests/test_kerngen.py b/kerngen/tests/test_kerngen.py index d7b42b4..f607fe8 100644 --- a/kerngen/tests/test_kerngen.py +++ b/kerngen/tests/test_kerngen.py @@ -1,10 +1,13 @@ # Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 """Test the expected behaviour of the kerngen script""" from enum import Enum from pathlib import Path from subprocess import run +from high_parser.parser import ParseResults +from high_parser.types import Context import pytest @@ -93,6 +96,28 @@ def test_invalid_scheme(kerngen_path): assert result.returncode != 0 +def test_parse_results_missing_context(): + """Test ParseResults constructor for missing context""" + with pytest.raises(LookupError) as e: + parse_results = ParseResults([], {}) + print(parse_results.context) # will generate a lookup error + assert "No Context found for commands list for ParseResults" in str(e.value) + + +def test_parse_results_multiple_context(): + """Test ParseResults constructor for multiple context""" + with pytest.raises(LookupError) as e: + parse_results = ParseResults( + [ + Context(scheme="BGV", poly_order=8192, max_rns=1), + Context(scheme="CKKS", poly_order=8192, max_rns=1), + ], + {}, + ) + print(parse_results.context) # will raise a LookupError + assert "Multiple Context found in commands list for ParseResults" in str(e.value) + + @pytest.fixture(name="gen_op_data") def fixture_gen_op_data(request): """Given an op name, return both the input and expected output strings""" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..c789606 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = kerngen