diff --git a/kerngen/high_parser/types.py b/kerngen/high_parser/types.py index c110350..5cca8d7 100644 --- a/kerngen/high_parser/types.py +++ b/kerngen/high_parser/types.py @@ -57,6 +57,9 @@ def from_polys(cls, poly: "Polys", *, mode: str | None = None) -> "Polys": case "last_rns": copy.start_rns = copy.rns - 1 return cls(**vars(copy)) + case "single_rns": + copy.rns = 1 + return cls(**vars(copy)) case "last_part": copy.start_parts = copy.parts - 1 return cls(**vars(copy)) diff --git a/kerngen/pisa_generators/mod.py b/kerngen/pisa_generators/mod.py index f99bf20..d249663 100644 --- a/kerngen/pisa_generators/mod.py +++ b/kerngen/pisa_generators/mod.py @@ -123,7 +123,7 @@ def generate_mod_stages() -> list[Stage]: temp_input_last_rns, p_half, Polys.from_polys( - input_remaining_rns, mode="drop_last_rns" + input_remaining_rns, mode="single_rns" ), last_q, ), diff --git a/kerngen/pisa_generators/rescale.py b/kerngen/pisa_generators/rescale.py index 951de89..3bfe56b 100644 --- a/kerngen/pisa_generators/rescale.py +++ b/kerngen/pisa_generators/rescale.py @@ -62,7 +62,7 @@ def to_pisa(self) -> list[PIsaOp]: temp_input_last_rns, temp_input_last_rns, q_last_half, - input_remaining_rns, + Polys.from_polys(input_remaining_rns, mode="single_rns"), last_q, ), Comment("Subtract q_i (last half/last rns) from y"),