Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BpOsdDecoder in v2 is slower #35

Closed
inmzhang opened this issue Mar 21, 2024 · 2 comments
Closed

BpOsdDecoder in v2 is slower #35

inmzhang opened this issue Mar 21, 2024 · 2 comments

Comments

@inmzhang
Copy link

inmzhang commented Mar 21, 2024

To test the BpOsdDecoder in the ldpc_v2 branch, I sampled and decoded some surface-code circuits with sinter.

For the baseline, I used stimbposd.sinter_decoders() (which internally uses ldpc.bposd_decoder from v1) to decode the samples. For v2, being aware of #32, I added the SinterCompiledBpOsdDecoder (and its SinterBpOsdDecoder factory) below for use with sinter:

import stim
import numpy as np
import pathlib
from ldpc.bposd_decoder import BpOsdDecoder
import sinter
from beliefmatching import detector_error_model_to_check_matrices


# BP+OSD settings chosen to match stimbposd's defaults, so that the ldpc v1
# baseline and the v2 decoder are configured the same way in the benchmark.
MAX_BP_ITERS: int = 30  # cap on belief-propagation iterations per shot
BP_METHOD: str = "ps"  # product-sum BP (in v1 "ps" uses probability ratios; "ps_log" is the log-domain variant)
OSD_ORDER: int = 60  # order passed to the OSD post-processing stage


class SinterCompiledBpOsdDecoder(sinter.CompiledDecoder):
    """Sinter CompiledDecoder that wraps an ldpc v2 BpOsdDecoder.

    Each shot's detection events are decoded into an estimated error vector.
    That vector is multiplied (mod 2) by the observables matrix to predict
    which logical observables flipped.
    """

    def __init__(self, bposd: BpOsdDecoder, num_dets: int, observables_matrix: np.ndarray):
        """Store the decoder and the data needed to unpack and project shots.

        Args:
            bposd: Configured BP+OSD decoder. Its ``decode`` maps one syndrome
                row to an error-vector estimate.
            num_dets: Number of detectors. It is used to drop the padding bits
                when unpacking the bit-packed detection events.
            observables_matrix: Observables-by-error-mechanisms matrix. Its
                column count must equal the length of ``bposd.decode``'s
                output. (Presumably a scipy sparse matrix produced by
                beliefmatching -- TODO confirm.)
        """
        self.bposd = bposd
        self.num_dets = num_dets
        self.observables_matrix = observables_matrix

    def decode_shots_bit_packed(self, *, bit_packed_detection_event_data: np.ndarray) -> np.ndarray:
        """Decode a batch of shots and return bit-packed observable predictions.

        Args:
            bit_packed_detection_event_data: One row per shot, with the
                detection events packed little-endian into bytes.

        Returns:
            A uint8 array with one row per shot, holding the predicted
            observable flips packed little-endian into bytes.
        """
        shots = np.unpackbits(
            bit_packed_detection_event_data, axis=1, count=self.num_dets, bitorder="little"
        )
        num_errors = self.observables_matrix.shape[1]
        # Decode into a preallocated array with an explicit loop.
        # np.apply_along_axis raises ValueError when the batch has zero shots,
        # and preallocating pins the dtype instead of inferring it from decode().
        corrs = np.zeros((shots.shape[0], num_errors), dtype=np.uint8)
        for i, syndrome in enumerate(shots):
            corrs[i] = self.bposd.decode(syndrome)
        # A uint8 matmul may wrap modulo 256. Because 256 is even, parity (mod 2)
        # is still correct.
        predictions = (corrs @ self.observables_matrix.T) % 2
        # packbits requires an integer/bool ndarray, so normalise the dtype first.
        return np.packbits(np.asarray(predictions, dtype=np.uint8), axis=1, bitorder="little")


class SinterBpOsdDecoder(sinter.Decoder):
    """Sinter Decoder factory for the ldpc v2 BpOsdDecoder.

    Holds the BP+OSD settings and builds one BpOsdDecoder per detector error
    model. It supports two paths:

    * the in-process ``compile_decoder_for_dem`` path;
    * sinter's file-based ``decode_via_files`` path.
    """

    def __init__(
        self,
        max_iter=0,
        bp_method="ms",
        ms_scaling_factor=0.625,
        schedule="parallel",
        omp_thread_count=1,
        serial_schedule_order=None,
        osd_method="osd0",
        osd_order=0,
    ):
        """Record the settings that are forwarded verbatim to ``BpOsdDecoder``."""
        self.max_iter = max_iter
        self.bp_method = bp_method
        self.ms_scaling_factor = ms_scaling_factor
        self.schedule = schedule
        self.omp_thread_count = omp_thread_count
        self.serial_schedule_order = serial_schedule_order
        self.osd_method = osd_method
        self.osd_order = osd_order

    def _build_bposd(self, check_matrices) -> BpOsdDecoder:
        """Build a BpOsdDecoder for *check_matrices* using this instance's settings.

        Both decoding paths share this single construction point, so their
        configurations cannot drift apart.
        """
        return BpOsdDecoder(
            check_matrices.check_matrix,
            error_channel=list(check_matrices.priors),
            max_iter=self.max_iter,
            bp_method=self.bp_method,
            ms_scaling_factor=self.ms_scaling_factor,
            schedule=self.schedule,
            omp_thread_count=self.omp_thread_count,
            serial_schedule_order=self.serial_schedule_order,
            osd_method=self.osd_method,
            osd_order=self.osd_order,
        )

    def compile_decoder_for_dem(
        self,
        *,
        dem: stim.DetectorErrorModel,
    ) -> sinter.CompiledDecoder:
        """Return an in-process compiled decoder for *dem*."""
        check_matrices = detector_error_model_to_check_matrices(dem)
        return SinterCompiledBpOsdDecoder(
            self._build_bposd(check_matrices),
            num_dets=dem.num_detectors,
            observables_matrix=check_matrices.observables_matrix,
        )

    def decode_via_files(
        self,
        *,
        num_shots: int,
        num_dets: int,
        num_obs: int,
        dem_path: pathlib.Path,
        dets_b8_in_path: pathlib.Path,
        obs_predictions_b8_out_path: pathlib.Path,
        tmp_dir: pathlib.Path,
    ) -> None:
        """Decode b8 detection data from disk and write b8 observable predictions.

        The loaded DEM, its check matrices and the built decoder are stored on
        ``self``, because ``decode`` reads ``self.bposd`` and ``self.matrices``.
        ``tmp_dir`` is unused.
        """
        self.dem = stim.DetectorErrorModel.from_file(dem_path)
        self.matrices = detector_error_model_to_check_matrices(self.dem)
        self.bposd = self._build_bposd(self.matrices)

        shots = stim.read_shot_data_file(path=dets_b8_in_path, format="b8", num_detectors=num_dets)
        predictions = np.zeros((num_shots, num_obs), dtype=bool)
        for i in range(num_shots):
            predictions[i, :] = self.decode(shots[i, :])

        stim.write_shot_data_file(
            data=predictions,
            path=obs_predictions_b8_out_path,
            format="b8",
            num_observables=num_obs,
        )

    def decode(self, syndrome: np.ndarray) -> np.ndarray:
        """Decode one syndrome and return its observable-flip prediction (mod 2).

        Requires a prior ``decode_via_files`` call, which sets ``self.bposd``
        and ``self.matrices``.
        """
        corr = self.bposd.decode(syndrome)
        return (self.matrices.observables_matrix @ corr) % 2


def bposd_sinter_decoder():
    """Return sinter custom decoders, keyed by the name to list in ``decoders=[...]``.

    The single entry, "bposdv2", is an ldpc v2 BP+OSD decoder configured with:

    * the module-level BP iteration cap (MAX_BP_ITERS);
    * the BP method (BP_METHOD);
    * the OSD order (OSD_ORDER);
    * the "OSD_CS" OSD method.
    """
    decoder = SinterBpOsdDecoder(
        max_iter=MAX_BP_ITERS,
        bp_method=BP_METHOD,
        osd_method="OSD_CS",
        osd_order=OSD_ORDER,
    )
    return {"bposdv2": decoder}

The values of MAX_BP_ITERS, BP_METHOD and OSD_ORDER were chosen to match the defaults in stimbposd.

Then, I ran the sampling with the following script:

import stim
import sinter
import stimbposd

import bposdtest # where I defined the sinter decoder for bposd v2

def generate_example_tasks():
    """Yield one sinter.Task per (p, d) pair for a rotated surface-code X-memory experiment.

    Physical error rates p form the outer loop and code distances d the inner
    loop. Each circuit runs d rounds and sets all four noise parameters below
    to p. The pair is recorded in the task's json_metadata.
    """
    error_rates = [0.004, 0.006, 0.008, 0.01, 0.012]
    distances = [3, 5, 7]
    for noise in error_rates:
        for distance in distances:
            noise_kwargs = dict(
                after_clifford_depolarization=noise,
                after_reset_flip_probability=noise,
                before_measure_flip_probability=noise,
                before_round_data_depolarization=noise,
            )
            circuit = stim.Circuit.generated(
                code_task="surface_code:rotated_memory_x",
                rounds=distance,
                distance=distance,
                **noise_kwargs,
            )
            yield sinter.Task(circuit=circuit, json_metadata={"p": noise, "d": distance})

if __name__ == "__main__":
    # Decoders under test: the ldpc v2 BP+OSD wrapper from bposdtest.
    # For the v1 baseline, use `custom_decoders = stimbposd.sinter_decoders()` instead.
    custom_decoders = bposdtest.bposd_sinter_decoder()
    sinter.collect(
        num_workers=120,  # sized for a many-core server; lower this on a workstation
        max_shots=500_000,
        max_errors=500,
        tasks=generate_example_tasks(),
        # Every name in `decoders` must be a key of `custom_decoders`.
        # Deriving the list from the dict keeps the two in sync; the previous
        # hard-coded "bposd2" did not match the "bposdv2" key.
        decoders=list(custom_decoders),
        custom_decoders=custom_decoders,
        print_progress=True,
        save_resume_filepath="bposd.csv",
    )

I compared the decoding accuracy and the decoding time per shot of the v1 and v2 implementations; here are the results:

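[Image] [Image] (plots comparing v1 and v2 decoding accuracy and time per shot; not preserved in this copy)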

We can clearly see that for surface codes with distance > 3, the v2 implementation decodes more slowly than v1. I'm not sure whether all conditions and arguments were set equivalently for v1 and v2 in this benchmark, so I might have missed something important.

The test environment was Ubuntu 22.04 LTS with Python 3.11.4.

@quantumgizmos
Copy link
Owner

Hi. Could you try setting bp_method='ps_log' for the v1 simulations? I think the speeds of the two versions should be more comparable once this change has been made.

The bp_method='ps' method uses a version of BP in which the messages are computed using probability ratios rather than log-probability ratios. This is quicker, as it avoids computing the arctanh(x) function needed for the log-likelihood-ratio (LLR) version. I haven't yet implemented the optimised version of product-sum BP for v2, but it is on my to-do list.

Cheers,
Joschka

@inmzhang
Copy link
Author

Thanks for your reply. After I set bp_method='ps_log' for v1, the speeds of the two versions are close.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants