forked from KeKsBoTer/c3dgs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__init__.py
132 lines (117 loc) · 4.61 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact [email protected]
#
import os
import random
import json
import shutil
from glob import glob

from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON


class Scene:
    """Dataset cameras plus a trained Gaussian model loaded from disk.

    The dataset under ``args.source_path`` is read as COLMAP (if it has a
    ``sparse`` folder) or Blender (if it has ``transforms_train.json``).
    Train/test camera lists are built for every requested resolution scale.
    The point cloud saved under
    ``<model_path>/point_cloud/iteration_<N>/point_cloud.*`` is loaded into
    ``gaussians``.
    """

    gaussians: GaussianModel

    def __init__(
        self,
        args: ModelParams,
        gaussians: GaussianModel,
        load_iteration=None,
        shuffle=True,
        resolution_scales=(1.0,),
        override_quantization=False,
    ):
        """Read the dataset cameras and load a trained checkpoint.

        :param args: model parameters. This method reads ``model_path``,
            ``source_path``, ``images``, ``eval`` and ``white_background``,
            and passes ``args`` on to ``cameraList_from_camInfos``.
        :param gaussians: model instance that the checkpoint is loaded into.
        :param load_iteration: iteration to load. ``-1`` selects the value
            returned by ``searchForMaxIteration`` for
            ``<model_path>/point_cloud``. A falsy value selects nothing. In
            that case ``input.ply`` and ``cameras.json`` are written to
            ``model_path`` and the constructor then raises.
        :param shuffle: shuffle the train and test camera infos in place
            before they are converted.
        :param resolution_scales: iterable of scales. One train and one test
            camera list is built per scale.
        :param override_quantization: forwarded to ``gaussians.load``.
        :raises AssertionError: if ``args.source_path`` is neither a COLMAP
            nor a Blender dataset.
        :raises FileNotFoundError: if no ``point_cloud.*`` file exists for the
            selected iteration.
        :raises Exception: if no iteration to load was selected.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(
                    os.path.join(self.model_path, "point_cloud")
                )
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras = {}
        self.test_cameras = {}

        scene_info = self._read_scene_info(args)

        if not self.loaded_iter:
            self._export_inputs(scene_info)

        if shuffle:
            # Shuffle once, before the per-scale loop, so every resolution
            # sees the same camera order (multi-res consistent shuffling).
            random.shuffle(scene_info.train_cameras)
            random.shuffle(scene_info.test_cameras)

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(
                scene_info.train_cameras, resolution_scale, args
            )
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(
                scene_info.test_cameras, resolution_scale, args
            )

        if not self.loaded_iter:
            raise Exception("no iteration to load was found")
        self.gaussians.load(
            self._checkpoint_path(),
            override_quantization=override_quantization,
        )

    @staticmethod
    def _read_scene_info(args):
        """Detect the dataset layout under ``args.source_path`` and read it."""
        if os.path.exists(os.path.join(args.source_path, "sparse")):
            return sceneLoadTypeCallbacks["Colmap"](
                args.source_path, args.images, args.eval
            )
        if os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            return sceneLoadTypeCallbacks["Blender"](
                args.source_path, args.white_background, args.eval
            )
        # Raised explicitly rather than via ``assert`` so the check survives
        # ``python -O``. The exception type is kept for existing callers.
        raise AssertionError("Could not recognize scene type!")

    def _export_inputs(self, scene_info):
        """Write ``input.ply`` and ``cameras.json`` into ``self.model_path``.

        ``input.ply`` is a copy of ``scene_info.ply_path``. ``cameras.json``
        lists the test cameras first and then the train cameras, numbered
        consecutively in that order.
        """
        # copyfile streams the data instead of reading the whole file into
        # memory, and refuses to copy a file onto itself.
        shutil.copyfile(
            scene_info.ply_path, os.path.join(self.model_path, "input.ply")
        )
        camlist = [
            *(scene_info.test_cameras or []),
            *(scene_info.train_cameras or []),
        ]
        json_cams = [
            camera_to_JSON(cam_id, cam) for cam_id, cam in enumerate(camlist)
        ]
        with open(os.path.join(self.model_path, "cameras.json"), "w") as file:
            json.dump(json_cams, file)

    def _checkpoint_path(self):
        """Return the saved point cloud file for ``self.loaded_iter``.

        Any ``point_cloud.*`` file in the iteration folder is accepted. The
        matches are sorted, so the choice does not depend on directory
        listing order when several formats are present.

        :raises FileNotFoundError: if no file matches.
        """
        pattern = os.path.join(
            self.model_path,
            "point_cloud",
            "iteration_" + str(self.loaded_iter),
            "point_cloud.*",
        )
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(f"no point cloud checkpoint matches '{pattern}'")
        return matches[0]

    def save(self, iteration, format="ply"):
        """Save the Gaussian model under ``point_cloud/iteration_<iteration>``.

        :param iteration: iteration number used to name the output folder
            inside ``self.model_path``.
        :param format: ``"ply"`` (via ``gaussians.save_ply``) or ``"npz"``
            (via ``gaussians.save_npz``).
        :raises Exception: for any other format.
        """
        point_cloud_path = os.path.join(
            self.model_path, "point_cloud/iteration_{}".format(iteration)
        )
        if format == "ply":
            self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"))
        elif format == "npz":
            self.gaussians.save_npz(os.path.join(point_cloud_path, "point_cloud.npz"))
        else:
            raise Exception(f"format '{format}' not supported")

    def getTrainCameras(self, scale=1.0):
        """Return the training cameras built for ``scale``.

        Raises ``KeyError`` if ``scale`` was not in ``resolution_scales``.
        """
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        """Return the test cameras built for ``scale``.

        Raises ``KeyError`` if ``scale`` was not in ``resolution_scales``.
        """
        return self.test_cameras[scale]