diff --git a/kerngen/pisa_generators/basic.py b/kerngen/pisa_generators/basic.py index 7771626..e96901b 100644 --- a/kerngen/pisa_generators/basic.py +++ b/kerngen/pisa_generators/basic.py @@ -1,10 +1,14 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + # Copyright (C) 2024 Intel Corporation """Module containing conversions or operations from isa to p-isa.""" import itertools as it from dataclasses import dataclass -from typing import ClassVar, Iterable +from typing import ClassVar, Iterable, Tuple +from string import ascii_letters import high_parser.pisa_operations as pisa_op from high_parser.pisa_operations import PIsaOp @@ -238,3 +242,57 @@ def to_pisa(self) -> list[PIsaOp]: pisa_op.Copy(self.context.label, *expand_io) for expand_io, _ in expand_ios(self.context, self.output, self.input0) ] + + +@dataclass +class KeyMul(HighOp): + """Class representing a key multiplication operation""" + + context: KernelContext + output: Polys + input0: Polys + input1: KeyPolys + input0_fixed_part: int + + def to_pisa(self) -> list[PIsaOp]: + """Return the p-isa code to perform a key multiplication""" + + def get_pisa_op(num): + yield 0, pisa_op.Mul + yield from ((op, pisa_op.Mac) for op in range(1, num)) + + ls: list[pisa_op] = [] + for digit, op in get_pisa_op(self.input1.digits): + input0_tmp = Polys.from_polys(self.input0) + input0_tmp.name += "_" + ascii_letters[digit] + ls.extend( + op( + self.context.label, + self.output(part, q, unit), + input0_tmp(self.input0_fixed_part, q, unit), + self.input1(digit, part, q, unit), + q, + ) + for part, q, unit in it.product( + range(self.input1.start_parts, self.input1.parts), + range(self.input0.start_rns, self.input0.rns), + range(self.context.units), + ) + ) + return ls + + +def extract_last_part_polys(input0: Polys, rns: int) -> Tuple[Polys, Polys, Polys]: + """Split and extract the last part of input0 with a change of rns""" + input_last_part = Polys.from_polys(input0, mode="last_part") + input_last_part.name = input0.name + + last_coeff = Polys.from_polys(input_last_part) + last_coeff.name = "coeffs" + last_coeff.rns = rns + + upto_last_coeffs = Polys.from_polys(last_coeff) + upto_last_coeffs.parts = 1 + upto_last_coeffs.start_parts = 0 + + return input_last_part, last_coeff, upto_last_coeffs diff --git a/kerngen/pisa_generators/decomp.py b/kerngen/pisa_generators/decomp.py new file mode 100644 index 0000000..691f75c --- /dev/null +++ b/kerngen/pisa_generators/decomp.py @@ -0,0 +1,59 @@ +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +"""Module containing digit decomposition/base extend""" + +from string import ascii_letters +import itertools as it + +from dataclasses import dataclass +import high_parser.pisa_operations as pisa_op +from high_parser.pisa_operations import PIsaOp +from high_parser import KernelContext, HighOp, Immediate, Polys + +from .basic import Muli, mixed_to_pisa_ops +from .ntt import INTT, NTT + + +@dataclass +class DigitDecompExtend(HighOp): + """Class representing Digit decomposition and base extension""" + + context: KernelContext + output: Polys + input0: Polys + + def to_pisa(self) -> list[PIsaOp]: + """Return the p-isa code performing Digit decomposition followed by + base extension""" + + rns_poly = Polys.from_polys(self.input0) + rns_poly.name = "ct" + + one = Immediate(name="one") + r2 = Immediate(name="R2", rns=self.context.key_rns) + + ls: list[pisa_op] = [] + for input_rns_index in range(self.input0.rns): + ls.extend( + pisa_op.Muli( + self.context.label, + self.output(part, pq, unit), + rns_poly(part, input_rns_index, unit), + r2(part, pq, unit), + pq, + ) + for part, pq, unit in it.product( + range(self.input0.start_parts, self.input0.parts), + range(self.context.key_rns), + range(self.context.units), + ) + ) + output_tmp = Polys.from_polys(self.output) + output_tmp.name += "_" + ascii_letters[input_rns_index] + ls.extend(NTT(self.context, output_tmp, self.output).to_pisa()) + + return mixed_to_pisa_ops( + INTT(self.context, rns_poly, self.input0), + Muli(self.context, rns_poly, rns_poly, one), + ls, + ) diff --git a/kerngen/pisa_generators/relin.py b/kerngen/pisa_generators/relin.py index 693fbd4..d3b27a2 100644 --- a/kerngen/pisa_generators/relin.py +++ b/kerngen/pisa_generators/relin.py @@ -1,101 +1,13 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -"""Module containing relin, keymul, etc.""" - +"""Module containing relin.""" from dataclasses import dataclass -from itertools import product -from string import ascii_letters - -import high_parser.pisa_operations as pisa_op from high_parser.pisa_operations import PIsaOp, Comment -from high_parser import KernelContext, HighOp, Immediate, KeyPolys, Polys - -from .basic import Add, Muli, mixed_to_pisa_ops +from high_parser import KernelContext, HighOp, KeyPolys, Polys +from .basic import Add, KeyMul, mixed_to_pisa_ops, extract_last_part_polys from .mod import Mod -from .ntt import INTT, NTT - - -@dataclass -class KeyMul(HighOp): - """Class representing a key multiplication operation""" - - context: KernelContext - output: Polys - input0: Polys - input1: KeyPolys - - def to_pisa(self) -> list[PIsaOp]: - """Return the p-isa code to perform a key multiplication""" - - def get_pisa_op(num): - yield 0, pisa_op.Mul - yield from ((op, pisa_op.Mac) for op in range(1, num)) - - ls: list[pisa_op] = [] - for digit, op in get_pisa_op(self.input1.digits): - input0_tmp = Polys.from_polys(self.input0) - input0_tmp.name += "_" + ascii_letters[digit] - ls.extend( - op( - self.context.label, - self.output(part, q, unit), - input0_tmp(2, q, unit), - self.input1(digit, part, q, unit), - q, - ) - for part, q, unit in product( - range(self.input1.start_parts, self.input1.parts), - range(self.input0.start_rns, self.input0.rns), - range(self.context.units), - ) - ) - return ls - - -@dataclass -class DigitDecompExtend(HighOp): - """Class representing Digit decomposition and base extension""" - - context: KernelContext - output: Polys - input0: Polys - - def to_pisa(self) -> list[PIsaOp]: - """Return the p-isa code performing Digit decomposition followed by - base extension""" - - rns_poly = Polys.from_polys(self.input0) - rns_poly.name = "ct" - - one = Immediate(name="one") - r2 = Immediate(name="R2", rns=self.context.key_rns) - - ls: list[pisa_op] = [] - for input_rns_index in range(self.input0.rns): - ls.extend( - pisa_op.Muli( - self.context.label, - self.output(part, pq, unit), - rns_poly(part, input_rns_index, unit), - r2(part, pq, unit), - pq, - ) - for part, pq, unit in product( - range(self.input0.start_parts, self.input0.parts), - range(self.context.key_rns), - range(self.context.units), - ) - ) - output_tmp = Polys.from_polys(self.output) - output_tmp.name += "_" + ascii_letters[input_rns_index] - ls.extend(NTT(self.context, output_tmp, self.output).to_pisa()) - - return mixed_to_pisa_ops( - INTT(self.context, rns_poly, self.input0), - Muli(self.context, rns_poly, rns_poly, one), - ls, - ) +from .decomp import DigitDecompExtend @dataclass @@ -120,15 +32,9 @@ def to_pisa(self) -> list[PIsaOp]: mul_by_rlk = Polys("c2_rlk", parts=2, rns=self.context.key_rns) mul_by_rlk_modded_down = Polys.from_polys(mul_by_rlk) mul_by_rlk_modded_down.rns = self.input0.rns - input_last_part = Polys.from_polys(self.input0, mode="last_part") - input_last_part.name = self.input0.name - - last_coeff = Polys.from_polys(input_last_part) - last_coeff.name = "coeffs" - last_coeff.rns = self.context.key_rns - upto_last_coeffs = Polys.from_polys(last_coeff) - upto_last_coeffs.parts = 1 - upto_last_coeffs.start_parts = 0 + input_last_part, last_coeff, upto_last_coeffs = extract_last_part_polys( + self.input0, self.context.key_rns + ) add_original = Polys.from_polys(mul_by_rlk_modded_down) add_original.name = self.input0.name @@ -138,7 +44,7 @@ def to_pisa(self) -> list[PIsaOp]: Comment("Digit decomposition and extend base from Q to PQ"), DigitDecompExtend(self.context, last_coeff, input_last_part), Comment("Multiply by relin key"), - KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key), + KeyMul(self.context, mul_by_rlk, upto_last_coeffs, relin_key, 2), Comment("Mod switch down to Q"), Mod(self.context, mul_by_rlk_modded_down, mul_by_rlk), Comment("Add to original poly"),