diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index ff0cb00..5bdcc27 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from itertools import product from string import ascii_letters +from typing import Tuple import high_parser.pisa_operations as pisa_op from high_parser.pisa_operations import PIsaOp, Comment @@ -16,6 +17,22 @@ from .ntt import INTT, NTT +def init_common_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]: + """Initialize commonly used polys in both relin and rotate kernels""" + input_last_part = Polys.from_polys(input, mode="last_part") + input_last_part.name = input0.name + + last_coeff = Polys.from_polys(input_last_part) + last_coeff.name = "coeffs" + last_coeff.rns = rns + + upto_last_coeffs = Polys.from_polys(last_coeff) + upto_last_coeffs.parts = 1 + upto_last_coeffs.start_parts = 0 + + return input_last_part, last_coeff, upto_last_coeffs + + @dataclass class KeyMul(HighOp): """Class representing a key multiplication operation""" @@ -121,15 +138,10 @@ def to_pisa(self) -> list[PIsaOp]: mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns) mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns - input_last_part = Polys.from_polys(self.input0, mode="last_part") - input_last_part.name = self.input0.name - - last_coeff = Polys.from_polys(input_last_part) - last_coeff.name = "coeffs" - last_coeff.rns = self.context.key_rns - upto_last_coeffs = Polys.from_polys(last_coeff) - upto_last_coeffs.parts = 1 - upto_last_coeffs.start_parts = 0 + + input_last_part, last_coeff, upto_last_coeffs = init_common_polys( + self.input0, self.context.key_rns + ) add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name diff --git a/kerngen/pisa_generators/rotate.py b/kerngen/pisa_generators/rotate.py index 23db077..69de82a 100644 --- a/kerngen/pisa_generators/rotate.py +++ b/kerngen/pisa_generators/rotate.py @@ -9,7 +9,7 @@ from high_parser import KernelContext, HighOp, Polys, KeyPolys from .basic import Add, mixed_to_pisa_ops -from .relin import KeyMul, DigitDecompExtend +from .relin import KeyMul, DigitDecompExtend, init_common_polys from .mod import Mod from .ntt import INTT, NTT @@ -23,7 +23,7 @@ class Rotate(HighOp): input0: Polys def to_pisa(self) -> list[PIsaOp]: - """Return the p-isa code to perform a relinearization (relin). Note: + """Return the p-isa code to perform a rotate. Note: currently only supports polynomials with two parts. Currently only supports number of digits equal to the RNS size""" self.output.parts = 2 @@ -31,22 +31,14 @@ def to_pisa(self) -> list[PIsaOp]: relin_key = KeyPolys( "gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns ) - # pylint: disable=duplicate-code mul_by_rlk = Polys("c2_gk", parts=2, rns=self.context.key_rns) mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns mul_by_rlk_modded_down.name = self.output.name - input_last_part = Polys.from_polys(self.input0, mode="last_part") - input_last_part.name = self.input0.name - - last_coeff = Polys.from_polys(input_last_part) - last_coeff.name = "coeffs" - last_coeff.rns = self.context.key_rns - - upto_last_coeffs = Polys.from_polys(last_coeff) - upto_last_coeffs.parts = 1 - upto_last_coeffs.start_parts = 0 + input_last_part, last_coeff, upto_last_coeffs = init_common_polys( + self.input0, self.context.key_rns + ) cd = Polys.from_polys(self.input0) cd.name = "cd" @@ -54,13 +46,12 @@ def to_pisa(self) -> list[PIsaOp]: cd.start_parts = 0 start_input = Polys.from_polys(self.input0) - start_input.start_parts = 0 start_input.parts = 1 + start_input.start_parts = 0 first_part_rlk = Polys.from_polys(mul_by_rlk_modded_down) first_part_rlk.parts = 1 first_part_rlk.start_parts = 0 - # pylint: enable=duplicate-code return mixed_to_pisa_ops( Comment( @@ -75,5 +66,5 @@ def to_pisa(self) -> list[PIsaOp]: INTT(self.context, cd, start_input), NTT(self.context, cd, cd), Add(self.context, self.output, cd, first_part_rlk), - # Comment("End of rotate kernel") + Comment("End of rotate kernel"), )