Skip to content

Commit

Permalink
eliminated duplicated code. Created a helper function in relin to ini…
Browse files Browse the repository at this point in the history
…tilize common polys for relin/rotate.
  • Loading branch information
christopherngutierrez committed Sep 13, 2024
1 parent ff50a0a commit 70a7a80
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
30 changes: 21 additions & 9 deletions kerngen/pisa_generators/relin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dataclasses import dataclass
from itertools import product
from string import ascii_letters
from typing import Tuple

import high_parser.pisa_operations as pisa_op
from high_parser.pisa_operations import PIsaOp, Comment
Expand All @@ -16,6 +17,22 @@
from .ntt import INTT, NTT


def init_common_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]:
"""Initialize commonly used polys in both relin and rotate kernels"""
input_last_part = Polys.from_polys(input, mode="last_part")
input_last_part.name = input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = rns

upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0

return input_last_part, last_coeff, upto_last_coeffs


@dataclass
class KeyMul(HighOp):
"""Class representing a key multiplication operation"""
Expand Down Expand Up @@ -121,15 +138,10 @@ def to_pisa(self) -> list[PIsaOp]:
mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns)
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns
input_last_part = Polys.from_polys(self.input0, mode="last_part")
input_last_part.name = self.input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = self.context.key_rns
upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0

input_last_part, last_coeff, upto_last_coeffs = init_common_polys(
self.input0, self.context.key_rns
)

add_original = Polys.from_polys(mul_by_rlk_modded_down)
add_original.name = self.input0.name
Expand Down
23 changes: 7 additions & 16 deletions kerngen/pisa_generators/rotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from high_parser import KernelContext, HighOp, Polys, KeyPolys

from .basic import Add, mixed_to_pisa_ops
from .relin import KeyMul, DigitDecompExtend
from .relin import KeyMul, DigitDecompExtend, init_common_polys
from .mod import Mod
from .ntt import INTT, NTT

Expand All @@ -23,44 +23,35 @@ class Rotate(HighOp):
input0: Polys

def to_pisa(self) -> list[PIsaOp]:
"""Return the p-isa code to perform a relinearization (relin). Note:
"""Return the p-isa code to perform a rotate. Note:
currently only supports polynomials with two parts. Currently only
supports number of digits equal to the RNS size"""
self.output.parts = 2
self.input0.parts = 2
relin_key = KeyPolys(
"gk", parts=2, rns=self.context.key_rns, digits=self.input0.rns
)
# pylint: disable=duplicate-code
mul_by_rlk = Polys("c2_gk", parts=2, rns=self.context.key_rns)
mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk)
mul_by_rlk_modded_down.rns = self.input0.rns
mul_by_rlk_modded_down.name = self.output.name

input_last_part = Polys.from_polys(self.input0, mode="last_part")
input_last_part.name = self.input0.name

last_coeff = Polys.from_polys(input_last_part)
last_coeff.name = "coeffs"
last_coeff.rns = self.context.key_rns

upto_last_coeffs = Polys.from_polys(last_coeff)
upto_last_coeffs.parts = 1
upto_last_coeffs.start_parts = 0
input_last_part, last_coeff, upto_last_coeffs = init_common_polys(
self.input0, self.context.key_rns
)

cd = Polys.from_polys(self.input0)
cd.name = "cd"
cd.parts = 1
cd.start_parts = 0

start_input = Polys.from_polys(self.input0)
start_input.start_parts = 0
start_input.parts = 1
start_input.start_parts = 0

first_part_rlk = Polys.from_polys(mul_by_rlk_modded_down)
first_part_rlk.parts = 1
first_part_rlk.start_parts = 0
# pylint: enable=duplicate-code

return mixed_to_pisa_ops(
Comment(
Expand All @@ -75,5 +66,5 @@ def to_pisa(self) -> list[PIsaOp]:
INTT(self.context, cd, start_input),
NTT(self.context, cd, cd),
Add(self.context, self.output, cd, first_part_rlk),
# Comment("End of rotate kernel")
Comment("End of rotate kernel"),
)

0 comments on commit 70a7a80

Please sign in to comment.