Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend Optional Parameters for Context #59

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions kerngen/high_parser/optional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""A module to process optional context parameters"""

from abc import ABC, abstractmethod


class OptionalContext(ABC):
"""Abstract class to hold optional parameters for context"""

op_name: str = ""
op_value = None

@abstractmethod
def validate(self, value):
"""Abstract method, which defines how to valudate a value"""


class OptionalInt(OptionalContext):
"""Holds a key/value pair for optional context parameters of type Int"""

def __init__(self, name: str, min_val: int, max_val: int):
self.min_val = min_val
self.max_val = max_val
self._op_name = name

def validate(self, value: int):
"""Validate numeric options with min/max range"""
if self.min_val < value < self.max_val:
return True
return False

@property
def op_value(self):
"""Get op_value"""
return self._op_value

@op_value.setter
def op_value(self, value: int):
"""Set op_value"""
if self.validate(value):
self._op_value = int(value)
else:
raise ValueError(
"{self.op_name} must be in range ({self.min_val}, {self.max_val}): {self.op_name}={self.op_value}"
)


class OptionalIntMinMax:
"""Holds min/max values for optional context parameters for type Int"""

int_min: int
int_max: int
default: int | None

def __init__(self, int_min: int, int_max: int, default: int | None):
self.int_min = int_min
self.int_max = int_max
self.default = default


class OptionalFactory(ABC):
"""Abstract class that creates OptionaContext objects"""

MAX_KRNS_DELTA = 128
MAX_DIGIT = 3
MIN_KRNS_DELTA = MIN_DIGIT = 0
optionals = {
"krns_delta": OptionalIntMinMax(MIN_KRNS_DELTA, MAX_KRNS_DELTA, 0),
"num_digits": OptionalIntMinMax(MIN_DIGIT, MAX_DIGIT, None),
}

@staticmethod
@abstractmethod
def create(name: str, value) -> OptionalContext:
"""Abstract method, to define how to create an OptionalContext"""


class OptionalIntFactory(OptionalFactory):
"""Optional context parameter factory for Int types"""

@staticmethod
def create(name: str, value: int) -> OptionalInt:
"""Create a OptionalInt object based on key/value pair"""
if name in OptionalIntFactory.optionals:
if isinstance(OptionalIntFactory.optionals[name], OptionalIntMinMax):
optional_int = OptionalInt(
name,
OptionalIntFactory.optionals[name].int_min,
OptionalIntFactory.optionals[name].int_max,
)
optional_int.op_value = value
# add other optional types here
else:
raise KeyError(f"Invalid optional name for Context: '{name}'")
return optional_int


class OptionalFactoryDispatcher:
"""An object dispatcher based on key/value pair for comptional context parameters"""

@staticmethod
def create(name: str, value) -> OptionalContext:
"""Creat an OptionalContext object based on the type of value passed in"""
if value.isnumeric():
value = int(value)
match value:
case int():
return OptionalIntFactory.create(name, value)
case _:
raise ValueError(f"Current type '{type(value)}' is not supported.")


class OptionalsParser:
"""Parses key/value pairs and returns a dictionary of optiona parameters"""

@staticmethod
def __default_values():
default_dict = {}
for key, val in OptionalFactory.optionals.items():
default_dict[key] = val.default
return default_dict

@staticmethod
def parse(optionals: list[str]):
"""Parse the optional parameter list and return a dictionary with values"""
output_dict = OptionalsParser.__default_values()
for option in optionals:
try:
key, value = option.split("=")
output_dict[key] = OptionalFactoryDispatcher.create(key, value).op_value
except ValueError as err:
raise ValueError(
f"Optional variables must be key/value pairs (e.g. krns_delta=1, num_digits=3): '{option}'"
) from err
return output_dict
18 changes: 10 additions & 8 deletions kerngen/high_parser/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .pisa_operations import PIsaOp

from .optional import OptionalsParser


class PolyOutOfBoundsError(Exception):
"""Exception for Poly attributes being out of bounds"""
Expand Down Expand Up @@ -149,18 +151,15 @@ class Context(BaseModel):
scheme: str
poly_order: int # the N
max_rns: int
# optional vars for context
key_rns: int | None
num_digits: int | None

@classmethod
def from_string(cls, line: str):
"""Construct context from a string"""
scheme, poly_order, max_rns, *optional = line.split()
try:
krns, *rest = optional
except ValueError:
krns = None
if optional != [] and rest != []:
raise ValueError(f"too many parameters for context given: {line}")
scheme, poly_order, max_rns, *optionals = line.split()
optional_dict = OptionalsParser.parse(optionals)
int_poly_order = int(poly_order)
if (
int_poly_order < MIN_POLY_SIZE
Expand All @@ -172,12 +171,15 @@ def from_string(cls, line: str):
)

int_max_rns = int(max_rns)
int_key_rns = int_max_rns + int(krns) if krns else None
int_key_rns = int_max_rns
int_key_rns += optional_dict.pop("krns_delta")

return cls(
scheme=scheme.upper(),
poly_order=int_poly_order,
max_rns=int_max_rns,
key_rns=int_key_rns,
**optional_dict,
)

@property
Expand Down
45 changes: 44 additions & 1 deletion kerngen/tests/test_kerngen.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,49 @@ def test_multiple_contexts(kerngen_path):
assert result.returncode != 0


def test_context_optional_without_key(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 1\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
"ValueError: Optional variables must be key/value pairs (e.g. krns_delta=1, num_digits=3): '1'"
in result.stderr
)
assert result.returncode != 0


def test_context_unsupported_optional_variable(kerngen_path):
"""Test kerngen raises an exception when more than one context is given"""
input_string = "CONTEXT BGV 16384 4 test=3\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert "Invalid optional name for Context: 'test'" in result.stderr
assert result.returncode != 0


@pytest.mark.parametrize("invalid", [-1, 256, 0.1, "str"])
def test_context_optional_invalid_values(kerngen_path, invalid):
"""Test kerngen raises an exception if value is out of range for correct key"""
input_string = f"CONTEXT BGV 16384 4 krns_delta={invalid}\nData a 2\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
)
assert not result.stdout
assert (
f"ValueError: Optional variables must be key/value pairs (e.g. krns_delta=1, num_digits=3): 'krns_delta={invalid}'"
in result.stderr
)
assert result.returncode != 0


def test_unrecognised_opname(kerngen_path):
"""Test kerngen raises an exception when receiving an unrecognised
opname"""
Expand Down Expand Up @@ -99,7 +142,7 @@ def test_invalid_scheme(kerngen_path):
@pytest.mark.parametrize("invalid_poly", [16000, 2**12, 2**13, 2**18])
def test_invalid_poly_order(kerngen_path, invalid_poly):
"""Poly order should be powers of two >= 2^14 and <= 2^17"""
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4 2\nADD a b c\n"
input_string = "CONTEXT BGV " + str(invalid_poly) + " 4\nADD a b c\n"
result = execute_process(
[kerngen_path],
data_in=input_string,
Expand Down
Loading