diff --git a/diff_gaussian_rasterization/__init__.py b/diff_gaussian_rasterization/__init__.py index bbef37d1..fbd8601c 100644 --- a/diff_gaussian_rasterization/__init__.py +++ b/diff_gaussian_rasterization/__init__.py @@ -95,10 +95,24 @@ def forward( ctx.raster_settings = raster_settings ctx.num_rendered = num_rendered ctx.save_for_backward(colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii, sh, geomBuffer, binningBuffer, imgBuffer) - return color, radii + + accumulation = None + if raster_settings.return_accumulation: + alignment = 128 + offset = (alignment - imgBuffer.data_ptr()) % alignment + total_size = raster_settings.image_height * raster_settings.image_width * 4 + accumulation = ( + imgBuffer[offset: offset + total_size] + .view(torch.float32) + .clone() + .mul_(-1) + .add_(1) + .view((raster_settings.image_height, raster_settings.image_width)) + ) + return color, radii, accumulation @staticmethod - def backward(ctx, grad_out_color, _): + def backward(ctx, grad_out_color, _1, _2): # Restore necessary values from context num_rendered = ctx.num_rendered @@ -167,6 +181,7 @@ class GaussianRasterizationSettings(NamedTuple): campos : torch.Tensor prefiltered : bool debug : bool + return_accumulation : bool class GaussianRasterizer(nn.Module): def __init__(self, raster_settings):