diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index 53eacee..786b437 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -74,10 +74,10 @@ def __init__(self, *args, **kwargs): self.digits = kwargs.get(digits, 1) super().__init__(*args, **{k: v for k, v in kwargs.items() if k != digits}) - # def expand(self, part: int, digit: int, q: int, unit: int) -> str: + # def expand(self, digit: int, part: int, q: int, unit: int) -> str: def expand(self, *args) -> str: """Returns a string of the expanded symbol w.r.t. digit, rns, part, and unit""" - part, digit, q, unit = args + digit, part, q, unit = args # Sanity bounds checks if ( self.start_parts > part >= self.parts diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index bf7d189..ff0cb00 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -24,6 +24,7 @@ class KeyMul(HighOp): output: Polys input0: Polys input1: KeyPolys + input0_part: int def to_pisa(self) -> list[PIsaOp]: """Return the p-isa code to perform a key multiplication""" @@ -40,7 +41,7 @@ def get_pisa_op(num): op( self.context.label, self.output(part, q, unit), - input0_tmp(2, q, unit), + input0_tmp(self.input0_part, q, unit), self.input1(digit, part, q, unit), q, ) @@ -138,7 +139,7 @@ def to_pisa(self) -> list[PIsaOp]: Comment("Digit decomposition and extend base from Q to PQ"), DigitDecompExtend(self.context, last_coeff, input_last_part), Comment("Multiply by relin key"), - KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key), + KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 2), Comment("Mod switch down to Q"), Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), Comment("Add to original poly"), diff --git a/kerngen/pisa_generators/rotate.py b/kerngen/pisa_generators/rotate.py index 2f8e536..23db077 100644 --- a/kerngen/pisa_generators/rotate.py +++ b/kerngen/pisa_generators/rotate.py @@ -28,31 +28,31 @@ def to_pisa(self) -> list[PIsaOp]: supports number of digits equal to the RNS size""" self.output.parts = 2 self.input0.parts = 2 - relin_key = KeyPolys( - "rlk", parts=2, rns=self.context.key_rns, digits=self.input0.rns + "gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns ) - mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns) + # pylint: disable=duplicate-code + mul_by_rlk = Polys("c2_gk", parts=2, rns=self.context.key_rns) mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns mul_by_rlk_modded_down.name = self.output.name + input_last_part = Polys.from_polys(self.input0, mode="last_part") input_last_part.name = self.input0.name last_coeff = Polys.from_polys(input_last_part) last_coeff.name = "coeffs" last_coeff.rns = self.context.key_rns + upto_last_coeffs = Polys.from_polys(last_coeff) upto_last_coeffs.parts = 1 upto_last_coeffs.start_parts = 0 - add_original = Polys.from_polys(mul_by_rlk_modded_down) - add_original.name = self.input0.name - cd = Polys.from_polys(self.input0) cd.name = "cd" cd.parts = 1 cd.start_parts = 0 + start_input = Polys.from_polys(self.input0) start_input.start_parts = 0 start_input.parts = 1 @@ -60,13 +60,15 @@ def to_pisa(self) -> list[PIsaOp]: first_part_rlk = Polys.from_polys(mul_by_rlk_modded_down) first_part_rlk.parts = 1 first_part_rlk.start_parts = 0 + # pylint: enable=duplicate-code + return mixed_to_pisa_ops( Comment( "Start of rotate kernel - similar to relin, except missing final add" ), DigitDecompExtend(self.context, last_coeff, input_last_part), Comment("Multiply by rotate key"), - KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key), + KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 1), Comment("Mod switch down to Q"), Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), Comment("Start of new code for rotate"),