Skip to content

Commit

Permalink
Enhancements to the NTT/INTT kernel to support support generic power …
Browse files Browse the repository at this point in the history
…of 2 polynomials - Tested for 16k, 32k, 64k, and 128k. (#46)

Enhancements to the NTT/INTT kernel to support support generic power of 2 polynomials - Tested for 16k, 32k, 64k, and 128k.

Co-authored-by: Flavio Bergamaschi <[email protected]>
  • Loading branch information
christopherngutierrez and faberga authored Sep 18, 2024
1 parent 7b3a125 commit cec8bb9
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions kerngen/pisa_generators/ntt.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""Module containing conversions or operations from isa to p-isa."""

Expand All @@ -12,6 +13,15 @@
from .basic import Mul, Muli, mixed_to_pisa_ops


def generate_unit_index(size: int, op: pisa_op.NTT | pisa_op.INTT):
"""Helper to return unit indices for ntt/intt"""
for i in range(int(size / 2)):
if issubclass(op, pisa_op.NTT):
yield (i, int(size / 2) + i, i * 2, i * 2 + 1)
else:
yield (i * 2, i * 2 + 1, i, int(size / 2) + i)


# pylint: disable=too-many-arguments
def butterflies_ops(
op: pisa_op.NTT | pisa_op.INTT,
Expand All @@ -21,42 +31,48 @@ def butterflies_ops(
input0: Polys,
*, # only kwargs after
init_input: bool = False,
unit_size: int = 8192
) -> list[PIsaOp]:
"""Helper to return butterflies pisa operations for NTT/INTT"""
ntt_stages = context.ntt_stages
ntt_stages_div_by_two = ntt_stages % 2

stage_dst_srcs = [
(
(stage, outtmp, output)
if ntt_stages_div_by_two == stage % 2
else (stage, output, outtmp)
)
for stage in range(ntt_stages)
]
ntt_stages_div_by_two = context.ntt_stages % 2

if init_input is True:
# intt
stage_dst_srcs = [
(
(stage, outtmp, output)
if ntt_stages_div_by_two == stage % 2
else (stage, output, outtmp)
)
for stage in range(context.ntt_stages)
]
stage_dst_srcs[0] = (
(0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, input0, outtmp)
(0, outtmp, input0) if ntt_stages_div_by_two == 0 else (0, output, input0)
)
else:
# ntt
stage_dst_srcs = [
((stage, outtmp, output) if stage % 2 == 0 else (stage, output, outtmp))
for stage in range(context.ntt_stages)
]

return [
op(
context.label,
dst(part, q, unit),
dst(part, q, next_unit),
src(part, q, unit),
src(part, q, next_unit),
dst(part, q, unit[0]),
dst(part, q, unit[1]),
src(part, q, unit[2]),
src(part, q, unit[3]),
stage,
unit,
unit[0] if issubclass(op, pisa_op.NTT) else unit[2],
q,
)
# units for omegas (aka w) taken from 16K onwards
for part, (stage, dst, src), q, (unit, next_unit) in it.product(
for part, (stage, dst, src), q, unit in it.product(
range(input0.start_parts, input0.parts),
stage_dst_srcs,
range(input0.start_rns, input0.rns),
it.pairwise(range(context.units)),
generate_unit_index(int(context.poly_order / unit_size), op),
)
]

Expand Down

0 comments on commit cec8bb9

Please sign in to comment.