Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better floating point support #816

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,10 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
#### --percent_dense
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
#### --data_dtype
The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
#### --store_images_as_uint8
Flag that describes how to store images in memory. If set, the images will be stored as uint8, and will be converted to the target data type on demand.

</details>
<br>
Expand Down
2 changes: 2 additions & 0 deletions arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(self, parser, sentinel=False):
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.data_dtype = "float32"
self.store_images_as_uint8 = False
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

Expand Down
4 changes: 4 additions & 0 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)

# after rasterization, we convert the resulting image to the target dtype
# The rasterizer expects parameters as float32, so the result is also float32.
rendered_image = rendered_image.to(viewpoint_camera.data_dtype)

# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,
Expand Down
39 changes: 31 additions & 8 deletions scene/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32,
store_images_as_uint8=True,
):
super(Camera, self).__init__()

Expand All @@ -28,6 +29,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
self.FoVx = FoVx
self.FoVy = FoVy
self.image_name = image_name
self.data_dtype = data_dtype

try:
self.data_device = torch.device(data_device)
Expand All @@ -36,14 +38,18 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")

self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
self.store_images_as_uint8 = store_images_as_uint8

if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
self._original_image = image.to(self.data_device)
self._gt_alpha_mask = gt_alpha_mask
if self._gt_alpha_mask is not None:
self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device)

if not store_images_as_uint8:
self._original_image = self.convert_image(self._original_image)

self.image_width = self._original_image.shape[2]
self.image_height = self._original_image.shape[1]

self.zfar = 100.0
self.znear = 0.01
Expand All @@ -56,6 +62,23 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]

def convert_image(self, image):
image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype)
gt_alpha_mask = self._gt_alpha_mask

if gt_alpha_mask is not None:
gt_alpha_mask = gt_alpha_mask / 255.0
image *= gt_alpha_mask.to(self.data_dtype)

return image

@property
def original_image(self):
if self.store_images_as_uint8:
return self.convert_image(self._original_image)
else:
return self._original_image

class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
self.image_width = width
Expand Down
6 changes: 4 additions & 2 deletions utils/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.general_utils import PILtoTorch, get_data_dtype
from utils.graphics_utils import fov2focal

WARNED = False
Expand Down Expand Up @@ -49,7 +49,9 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
data_dtype=get_data_dtype(args.data_dtype),
store_images_as_uint8=args.store_images_as_uint8)

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []
Expand Down
11 changes: 10 additions & 1 deletion utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def inverse_sigmoid(x):

def PILtoTorch(pil_image, resolution):
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
resized_image = torch.from_numpy(np.array(resized_image_PIL))
if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1)
else:
Expand Down Expand Up @@ -131,3 +131,12 @@ def flush(self):
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))

def get_data_dtype(dtype):
if dtype == "float32":
return torch.float32
elif dtype == "float64":
return torch.float64
elif dtype == "float16":
return torch.float16
return torch.float32