From 18eb6d6a0c363cf1023fd7b74ffd0dd92bdc35cd Mon Sep 17 00:00:00 2001 From: Stefan Saraev Date: Wed, 24 Apr 2024 18:37:06 +0300 Subject: [PATCH 1/3] added support for setting floating point range Users may want to reduce their memory consumption by using fp16. However, in my tests, such attempts will result in lower quality renders. Some data type conversions did not have any impact, so I removed them completely. --- README.md | 2 ++ arguments/__init__.py | 1 + gaussian_renderer/__init__.py | 4 ++++ render.py | 10 ++++++---- scene/cameras.py | 7 ++++--- scene/gaussian_model.py | 27 ++++++++++++++------------- train.py | 3 ++- utils/camera_utils.py | 7 +++++-- utils/general_utils.py | 21 +++++++++++++++------ 9 files changed, 53 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 4cbd3326d..28f2deeee 100644 --- a/README.md +++ b/README.md @@ -194,6 +194,8 @@ python train.py -s Influence of SSIM on total loss from 0 to 1, ```0.2``` by default. #### --percent_dense Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. + #### --data_dtype + The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 1e13a551e..3cad0b357 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -53,6 +53,7 @@ def __init__(self, parser, sentinel=False): self._resolution = -1 self._white_background = False self.data_device = "cuda" + self.data_dtype = "float32" self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index f74e336af..e8af83186 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -92,6 +92,10 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, rotations = rotations, cov3D_precomp = cov3D_precomp) + # after rasterization, we convert the resulting image to the target dtype + # The rasterizer expects parameters as float32, so the result is also float32. + rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype) + # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. return {"render": rendered_image, diff --git a/render.py b/render.py index fc6b82de8..70f85cb22 100644 --- a/render.py +++ b/render.py @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) -def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32): with torch.no_grad(): - gaussians = GaussianModel(dataset.sh_degree) + gaussians = GaussianModel(dataset.sh_degree, dtype=dtype) scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") + background = torch.tensor(bg_color, dtype=dtype, device="cuda") if not skip_train: render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) @@ -62,5 +62,7 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam # Initialize system state (RNG) safe_state(args.quiet) + + dtype = torch.float32 - render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index abf6e5242..0609d0a46 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,7 @@ class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32 ): super(Camera, self).__init__() @@ -28,6 +28,7 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.FoVx = FoVx self.FoVy = FoVy self.image_name = image_name + self.data_dtype = data_dtype try: self.data_device = torch.device(data_device) @@ -36,12 +37,12 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_device) + self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device) self.image_width = self.original_image.shape[2] self.image_height = self.original_image.shape[1] if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_device) + self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device) else: self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index 632a1e8e1..f7905588d 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -23,11 +23,11 @@ class GaussianModel: - def setup_functions(self): + def setup_functions(self, dtype): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): - L = build_scaling_rotation(scaling_modifier * scaling, rotation) + L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype) actual_covariance = L @ L.transpose(1, 2) - symm = strip_symmetric(actual_covariance) + symm = strip_symmetric(actual_covariance, dtype) return symm self.scaling_activation = torch.exp @@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): self.rotation_activation = torch.nn.functional.normalize - def __init__(self, sh_degree : int): + def __init__(self, sh_degree : int, dtype=torch.float32): self.active_sh_degree = 0 self.max_sh_degree = sh_degree self._xyz = torch.empty(0) @@ -56,7 +56,8 @@ def __init__(self, sh_degree : int): self.optimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 - self.setup_functions() + self.dtype = dtype + self.setup_functions(dtype) def capture(self): return ( @@ -136,7 +137,7 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 - opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) @@ -144,7 +145,7 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) - self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype) def training_setup(self, training_args): self.percent_dense = training_args.percent_dense @@ -246,12 +247,12 @@ def load_ply(self, path): for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) - self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) - self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) - self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) - self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True)) self.active_sh_degree = self.max_sh_degree diff --git a/train.py b/train.py index 5d819b348..7435218a9 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from gaussian_renderer import render, network_gui import sys from scene import Scene, GaussianModel -from utils.general_utils import safe_state +from utils.general_utils import get_data_dtype, safe_state import uuid from tqdm import tqdm from utils.image_utils import psnr @@ -216,6 +216,7 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) + training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 1a54d0ada..6e762a5f6 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -11,7 +11,7 @@ from scene.cameras import Camera import numpy as np -from utils.general_utils import PILtoTorch +from utils.general_utils import PILtoTorch, get_data_dtype from utils.graphics_utils import fov2focal WARNED = False @@ -39,6 +39,8 @@ def loadCam(args, id, cam_info, resolution_scale): resolution = (int(orig_w / scale), int(orig_h / scale)) resized_image_rgb = PILtoTorch(cam_info.image, resolution) + + # resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype)) gt_image = resized_image_rgb[:3, ...] loaded_mask = None @@ -49,7 +51,8 @@ def loadCam(args, id, cam_info, resolution_scale): return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, - image_name=cam_info.image_name, uid=id, data_device=args.data_device) + image_name=cam_info.image_name, uid=id, data_device=args.data_device, + data_dtype=get_data_dtype(args.data_dtype)) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index 541c08252..ed7f0a6ed 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -61,8 +61,8 @@ def helper(step): return helper -def strip_lowerdiag(L): - uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") +def strip_lowerdiag(L, dtype=torch.float32): + uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda") uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] @@ -72,8 +72,8 @@ def strip_lowerdiag(L): uncertainty[:, 5] = L[:, 2, 2] return uncertainty -def strip_symmetric(sym): - return strip_lowerdiag(sym) +def strip_symmetric(sym, dtype=torch.float32): + return strip_lowerdiag(sym, dtype=dtype) def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) @@ -98,8 +98,8 @@ def build_rotation(r): R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R -def build_scaling_rotation(s, r): - L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") +def build_scaling_rotation(s, r, dtype=torch.float32): + L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda") R = build_rotation(r) L[:,0,0] = s[:,0] @@ -131,3 +131,12 @@ def flush(self): np.random.seed(0) torch.manual_seed(0) torch.cuda.set_device(torch.device("cuda:0")) + +def get_data_dtype(dtype): + if dtype == "float32": + return torch.float32 + elif dtype == "float64": + return torch.float64 + elif dtype == "float16": + return torch.float16 + return torch.float32 \ No newline at end of file From b5a5f72eda16752559d07e574db0c50499168d5a Mon Sep 17 00:00:00 2001 From: Stefan Saraev Date: Thu, 2 May 2024 14:48:56 +0300 Subject: [PATCH 2/3] load speedup: refactored image loading Images are now loaded on the target device as uint8s. Then they are converted to the target data type (eg. fp32 or fp16). This speeds up the loading time. Also, users can opt to store the image as uint8 or as target data type. This will further reduce memory usage. --- README.md | 2 ++ arguments/__init__.py | 1 + gaussian_renderer/__init__.py | 2 +- render.py | 10 ++++----- scene/cameras.py | 38 +++++++++++++++++++++++++++-------- utils/camera_utils.py | 3 ++- utils/general_utils.py | 2 +- 7 files changed, 41 insertions(+), 17 deletions(-) diff --git a/README.md b/README.md index 28f2deeee..14278cd79 100644 --- a/README.md +++ b/README.md @@ -196,6 +196,8 @@ python train.py -s Percentage of scene extent (0--1) a point must exceed to be forcibly densified, ```0.01``` by default. #### --data_dtype The data type (float32, float16) in which images are stored when computing the loss. ```float32``` by default. + #### --store_images_as_uint8 + Flag that describes how to store images in memory. If set, the images will be stored as uint8, and will be converted to the target data type on demand.
diff --git a/arguments/__init__.py b/arguments/__init__.py index 3cad0b357..fdfe3fc65 100644 --- a/arguments/__init__.py +++ b/arguments/__init__.py @@ -54,6 +54,7 @@ def __init__(self, parser, sentinel=False): self._white_background = False self.data_device = "cuda" self.data_dtype = "float32" + self.store_images_as_uint8 = False self.eval = False super().__init__(parser, "Loading Parameters", sentinel) diff --git a/gaussian_renderer/__init__.py b/gaussian_renderer/__init__.py index e8af83186..3efab6e03 100644 --- a/gaussian_renderer/__init__.py +++ b/gaussian_renderer/__init__.py @@ -94,7 +94,7 @@ def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, # after rasterization, we convert the resulting image to the target dtype # The rasterizer expects parameters as float32, so the result is also float32. - rendered_image = rendered_image.to(viewpoint_camera.original_image.dtype) + rendered_image = rendered_image.to(viewpoint_camera.data_dtype) # Those Gaussians that were frustum culled or had a radius of 0 were not visible. # They will be excluded from value updates used in the splitting criteria. diff --git a/render.py b/render.py index 70f85cb22..fc54831db 100644 --- a/render.py +++ b/render.py @@ -34,13 +34,13 @@ def render_set(model_path, name, iteration, views, gaussians, pipeline, backgrou torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png")) torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png")) -def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool, dtype=torch.float32): +def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool): with torch.no_grad(): - gaussians = GaussianModel(dataset.sh_degree, dtype=dtype) + gaussians = GaussianModel(dataset.sh_degree) scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False) bg_color = [1,1,1] if dataset.white_background else [0, 0, 0] - background = torch.tensor(bg_color, dtype=dtype, device="cuda") + background = torch.tensor(bg_color, dtype=torch.float32, device="cuda") if not skip_train: render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background) @@ -63,6 +63,4 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam # Initialize system state (RNG) safe_state(args.quiet) - dtype = torch.float32 - - render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, dtype) \ No newline at end of file + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/scene/cameras.py b/scene/cameras.py index 0609d0a46..5264a04c6 100644 --- a/scene/cameras.py +++ b/scene/cameras.py @@ -17,7 +17,8 @@ class Camera(nn.Module): def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, image_name, uid, - trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32 + trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda", data_dtype=torch.float32, + store_images_as_uint8=True, ): super(Camera, self).__init__() @@ -37,14 +38,18 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) self.data_device = torch.device("cuda") - self.original_image = image.clamp(0.0, 1.0).to(self.data_dtype).to(self.data_device) - self.image_width = self.original_image.shape[2] - self.image_height = self.original_image.shape[1] + self.store_images_as_uint8 = store_images_as_uint8 - if gt_alpha_mask is not None: - self.original_image *= gt_alpha_mask.to(self.data_dtype).to(self.data_device) - else: - self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) + self._original_image = image.to(self.data_device) + self._gt_alpha_mask = gt_alpha_mask + if self._gt_alpha_mask is not None: + self._gt_alpha_mask = self._gt_alpha_mask.to(self.data_device) + + if not store_images_as_uint8: + self._original_image = self.convert_image(self._original_image) + + self.image_width = self._original_image.shape[2] + self.image_height = self._original_image.shape[1] self.zfar = 100.0 self.znear = 0.01 @@ -57,6 +62,23 @@ def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) self.camera_center = self.world_view_transform.inverse()[3, :3] + def convert_image(self, image): + image = (image / 255.0).clamp(0.0, 1.0).to(self.data_dtype) + gt_alpha_mask = self._gt_alpha_mask + + if gt_alpha_mask is not None: + gt_alpha_mask = gt_alpha_mask / 255.0 + image *= gt_alpha_mask.to(self.data_dtype) + + return image + + @property + def original_image(self): + if self.store_images_as_uint8: + return self.convert_image(self._original_image) + else: + return self._original_image + class MiniCam: def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): self.image_width = width diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 6e762a5f6..400e7b605 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -52,7 +52,8 @@ def loadCam(args, id, cam_info, resolution_scale): FoVx=cam_info.FovX, FoVy=cam_info.FovY, image=gt_image, gt_alpha_mask=loaded_mask, image_name=cam_info.image_name, uid=id, data_device=args.data_device, - data_dtype=get_data_dtype(args.data_dtype)) + data_dtype=get_data_dtype(args.data_dtype), + store_images_as_uint8=args.store_images_as_uint8) def cameraList_from_camInfos(cam_infos, resolution_scale, args): camera_list = [] diff --git a/utils/general_utils.py b/utils/general_utils.py index ed7f0a6ed..4b0c53aec 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -20,7 +20,7 @@ def inverse_sigmoid(x): def PILtoTorch(pil_image, resolution): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL))# / 255.0 if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: From 39fb001ef0b4b08628be47dce62dbf42bff2215b Mon Sep 17 00:00:00 2001 From: Stefan Saraev Date: Fri, 17 May 2024 17:13:26 +0300 Subject: [PATCH 3/3] chore: prepare for pull-request -> removing debug commentaries -> removing unused proposed code --- render.py | 2 +- scene/gaussian_model.py | 27 +++++++++++++-------------- train.py | 3 +-- utils/camera_utils.py | 2 -- utils/general_utils.py | 16 ++++++++-------- 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/render.py b/render.py index fc54831db..fc6b82de8 100644 --- a/render.py +++ b/render.py @@ -62,5 +62,5 @@ def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParam # Initialize system state (RNG) safe_state(args.quiet) - + render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test) \ No newline at end of file diff --git a/scene/gaussian_model.py b/scene/gaussian_model.py index f7905588d..632a1e8e1 100644 --- a/scene/gaussian_model.py +++ b/scene/gaussian_model.py @@ -23,11 +23,11 @@ class GaussianModel: - def setup_functions(self, dtype): + def setup_functions(self): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): - L = build_scaling_rotation(scaling_modifier * scaling, rotation, dtype) + L = build_scaling_rotation(scaling_modifier * scaling, rotation) actual_covariance = L @ L.transpose(1, 2) - symm = strip_symmetric(actual_covariance, dtype) + symm = strip_symmetric(actual_covariance) return symm self.scaling_activation = torch.exp @@ -41,7 +41,7 @@ def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): self.rotation_activation = torch.nn.functional.normalize - def __init__(self, sh_degree : int, dtype=torch.float32): + def __init__(self, sh_degree : int): self.active_sh_degree = 0 self.max_sh_degree = sh_degree self._xyz = torch.empty(0) @@ -56,8 +56,7 @@ def __init__(self, sh_degree : int, dtype=torch.float32): self.optimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 - self.dtype = dtype - self.setup_functions(dtype) + self.setup_functions() def capture(self): return ( @@ -137,7 +136,7 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 - opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=self.dtype, device="cuda")) + opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) @@ -145,7 +144,7 @@ def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float): self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) - self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda", dtype=self.dtype) + self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") def training_setup(self, training_args): self.percent_dense = training_args.percent_dense @@ -247,12 +246,12 @@ def load_ply(self, path): for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) - self._xyz = nn.Parameter(torch.tensor(xyz, dtype=self.dtype, device="cuda").requires_grad_(True)) - self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=self.dtype, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) - self._opacity = nn.Parameter(torch.tensor(opacities, dtype=self.dtype, device="cuda").requires_grad_(True)) - self._scaling = nn.Parameter(torch.tensor(scales, dtype=self.dtype, device="cuda").requires_grad_(True)) - self._rotation = nn.Parameter(torch.tensor(rots, dtype=self.dtype, device="cuda").requires_grad_(True)) + self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) + self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) + self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) + self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) + self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) self.active_sh_degree = self.max_sh_degree diff --git a/train.py b/train.py index 7435218a9..5d819b348 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from gaussian_renderer import render, network_gui import sys from scene import Scene, GaussianModel -from utils.general_utils import get_data_dtype, safe_state +from utils.general_utils import safe_state import uuid from tqdm import tqdm from utils.image_utils import psnr @@ -216,7 +216,6 @@ def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_i # Start GUI server, configure and run training network_gui.init(args.ip, args.port) torch.autograd.set_detect_anomaly(args.detect_anomaly) - training(lp.extract(args), op.extract(args), pp.extract(args), args.test_iterations, args.save_iterations, args.checkpoint_iterations, args.start_checkpoint, args.debug_from) # All done diff --git a/utils/camera_utils.py b/utils/camera_utils.py index 400e7b605..6c886f8a2 100644 --- a/utils/camera_utils.py +++ b/utils/camera_utils.py @@ -39,8 +39,6 @@ def loadCam(args, id, cam_info, resolution_scale): resolution = (int(orig_w / scale), int(orig_h / scale)) resized_image_rgb = PILtoTorch(cam_info.image, resolution) - - # resized_image_rgb = resized_image_rgb.to(get_data_dtype(args.data_dtype)) gt_image = resized_image_rgb[:3, ...] loaded_mask = None diff --git a/utils/general_utils.py b/utils/general_utils.py index 4b0c53aec..f060e14ed 100644 --- a/utils/general_utils.py +++ b/utils/general_utils.py @@ -20,7 +20,7 @@ def inverse_sigmoid(x): def PILtoTorch(pil_image, resolution): resized_image_PIL = pil_image.resize(resolution) - resized_image = torch.from_numpy(np.array(resized_image_PIL))# / 255.0 + resized_image = torch.from_numpy(np.array(resized_image_PIL)) if len(resized_image.shape) == 3: return resized_image.permute(2, 0, 1) else: @@ -61,8 +61,8 @@ def helper(step): return helper -def strip_lowerdiag(L, dtype=torch.float32): - uncertainty = torch.zeros((L.shape[0], 6), dtype=dtype, device="cuda") +def strip_lowerdiag(L): + uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] @@ -72,8 +72,8 @@ def strip_lowerdiag(L, dtype=torch.float32): uncertainty[:, 5] = L[:, 2, 2] return uncertainty -def strip_symmetric(sym, dtype=torch.float32): - return strip_lowerdiag(sym, dtype=dtype) +def strip_symmetric(sym): + return strip_lowerdiag(sym) def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) @@ -98,8 +98,8 @@ def build_rotation(r): R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R -def build_scaling_rotation(s, r, dtype=torch.float32): - L = torch.zeros((s.shape[0], 3, 3), dtype=dtype, device="cuda") +def build_scaling_rotation(s, r): + L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") R = build_rotation(r) L[:,0,0] = s[:,0] @@ -139,4 +139,4 @@ def get_data_dtype(dtype): return torch.float64 elif dtype == "float16": return torch.float16 - return torch.float32 \ No newline at end of file + return torch.float32