Skip to content

Commit

Permalink
added support for setting floating point range
Browse files Browse the repository at this point in the history
Users may want to reduce their memory consumption by using fp16.
However, in my tests, such attempts resulted in lower-quality renders.
Some data type conversions did not have any impact, so I removed them completely.
  • Loading branch information
PerhapsS44 authored and Matei Barbu committed May 15, 2024
1 parent 472689c commit 18eb6d6
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 29 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ python train.py -s <path to COLMAP or NeRF Synthetic dataset>
Influence of SSIM on total loss from 0 to 1, ```0.2``` by default.
#### --percent_dense
Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default.
#### --data_dtype
The data type (float32, float16 or float64) in which images are stored when computing the loss, ```float32``` by default.

</details>
<br>
Expand Down
1 change: 1 addition & 0 deletions arguments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(self, parser, sentinel=False):
self._resolution = -1
self._white_background = False
self.data_device = "cuda"
self.data_dtype = "float32"
self.eval = False
super().__init__(parser, "Loading Parameters", sentinel)

Expand Down
4 changes: 4 additions & 0 deletions gaussian_renderer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)

# after rasterization, we convert the resulting image to the target dtype
# The rasterizer expects parameters as float32, so the result is also float32.
rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype)

# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,
Expand Down
10 changes: 6 additions & 4 deletions render.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))

def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
gaussians = GaussianModel(dataset.sh_degree, dtype=dtype)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
background = torch.tensor(bg_color, dtype=dtype, device="cuda")

if not skip_train:
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
Expand All @@ -62,5 +62,7 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam

# Initialize system state (RNG)
safe_state(args.quiet)

dtype = torch.float32

render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype)
7 changes: 4 additions & 3 deletions scene/cameras.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32
):
super(Camera, self).__init__()

Expand All @@ -28,6 +28,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
self.FoVx = FoVx
self.FoVy = FoVy
self.image_name = image_name
self.data_dtype = data_dtype

try:
self.data_device = torch.device(data_device)
Expand All @@ -36,12 +37,12 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")

self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device)
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]

if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

Expand Down
27 changes: 14 additions & 13 deletions scene/gaussian_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@

class GaussianModel:

def setup_functions(self):
def setup_functions(self, dtype):
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
symm = strip_symmetric(actual_covariance, dtype)
return symm

self.scaling_activation = torch.exp
Expand All @@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
self.rotation_activation = torch.nn.functional.normalize


def __init__(self, sh_degree : int):
def __init__(self, sh_degree : int, dtype=torch.float32):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree
self._xyz = torch.empty(0)
Expand All @@ -56,7 +56,8 @@ def __init__(self, sh_degree : int):
self.optimizer = None
self.percent_dense = 0
self.spatial_lr_scale = 0
self.setup_functions()
self.dtype = dtype
self.setup_functions(dtype)

def capture(self):
return (
Expand Down Expand Up @@ -136,15 +137,15 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
rots[:, 0] = 1

opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda"))

self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
self._scaling = nn.Parameter(scales.requires_grad_(True))
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype)

def training_setup(self, training_args):
self.percent_dense = training_args.percent_dense
Expand Down Expand Up @@ -246,12 +247,12 @@ def load_ply(self, path):
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True))

self.active_sh_degree = self.max_sh_degree

Expand Down
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state
from utils.general_utils import get_data_dtype, safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
Expand Down Expand Up @@ -216,6 +216,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i
# Start GUI server, configure and run training
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)

training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from)

# All done
Expand Down
7 changes: 5 additions & 2 deletions utils/camera_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.general_utils import PILtoTorch, get_data_dtype
from utils.graphics_utils import fov2focal

WARNED = False
Expand Down Expand Up @@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale):
resolution = (int(orig_w / scale), int(orig_h / scale))

resized_image_rgb = PILtoTorch(cam_info.image, resolution)

# resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype))

gt_image = resized_image_rgb[:3, ...]
loaded_mask = None
Expand All @@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale):
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
image=gt_image, gt_alpha_mask=loaded_mask,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
image_name=cam_info.image_name, uid=id, data_device=args.data_device,
data_dtype=get_data_dtype(args.data_dtype))

def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []
Expand Down
21 changes: 15 additions & 6 deletions utils/general_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def helper(step):

return helper

def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
def strip_lowerdiag(L, dtype=torch.float32):
uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda")

uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
Expand All @@ -72,8 +72,8 @@ def strip_lowerdiag(L):
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty

def strip_symmetric(sym):
return strip_lowerdiag(sym)
def strip_symmetric(sym, dtype=torch.float32):
return strip_lowerdiag(sym, dtype=dtype)

def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
Expand All @@ -98,8 +98,8 @@ def build_rotation(r):
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R

def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
def build_scaling_rotation(s, r, dtype=torch.float32):
L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda")
R = build_rotation(r)

L[:,0,0] = s[:,0]
Expand Down Expand Up @@ -131,3 +131,12 @@ def flush(self):
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))

def get_data_dtype(dtype):
    """Translate a floating point type name into the matching ``torch.dtype``.

    Args:
        dtype: Name of the requested type; the recognized names are
            ``"float16"``, ``"float32"`` and ``"float64"``.

    Returns:
        The ``torch.dtype`` for a recognized name. Any other value falls back
        to ``torch.float32``; a warning is printed so that a mistyped option
        is not silently ignored.
    """
    supported = {
        "float16": torch.float16,
        "float32": torch.float32,
        "float64": torch.float64,
    }
    try:
        return supported[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also ended up
        # on the float32 fallback path. The fallback value is kept so existing
        # callers keep working, but the mismatch is now reported.
        print(f"[Warning] Unknown data dtype {dtype!r}, fallback to default float32")
        return torch.float32

This comment has been minimized.

Copy link
@mateibarbu19

mateibarbu19 May 15, 2024

Should you throw an error?

0 comments on commit 18eb6d6

Please sign in to comment.