sparse optimization for embeddings
yashbhalgat committed Feb 3, 2022
1 parent 7d56ffe commit f8a300e
Showing 5 changed files with 58 additions and 11 deletions.
4 changes: 2 additions & 2 deletions hash_encoding.py
@@ -22,7 +22,7 @@ def __init__(self, bounding_box, n_levels=16, n_features_per_level=2,\
         self.b = torch.exp((torch.log(self.finest_resolution)-torch.log(self.base_resolution))/(n_levels-1))

         self.embeddings = nn.ModuleList([nn.Embedding(2**self.log2_hashmap_size, \
-                                        self.n_features_per_level) for i in range(n_levels)])
+                                        self.n_features_per_level, sparse=True) for i in range(n_levels)])
         # custom uniform initialization
         for i in range(n_levels):
             nn.init.uniform_(self.embeddings[i].weight, a=-0.0001, b=0.0001)
@@ -153,4 +153,4 @@ def forward(self, input, **kwargs):
         result[..., 23] = self.C4[7] * xz * (xx - 3 * yy)
         result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))

-        return result
+        return result
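
For context on the sparse=True change above: a sparse nn.Embedding returns its gradient as a torch.sparse tensor covering only the rows that were actually looked up, and plain torch.optim.Adam raises an error on sparse gradients (SparseAdam also takes no weight_decay argument, which is why weight decay moves to the dense optimizer in run_nerf.py below). A minimal standalone sketch of that behaviour, not part of this commit; the table size and init range mirror the code above, everything else is illustrative:

    # Standalone sketch (not from this repo): a sparse nn.Embedding yields a
    # torch.sparse gradient covering only the looked-up rows, so it has to be
    # stepped by a sparse-aware optimizer such as torch.optim.SparseAdam.
    import torch
    import torch.nn as nn

    table = nn.Embedding(2**19, 2, sparse=True)             # hashmap-sized table, 2 features per level
    nn.init.uniform_(table.weight, a=-0.0001, b=0.0001)     # same init range as above

    opt = torch.optim.SparseAdam(table.parameters(), lr=0.01, betas=(0.9, 0.99), eps=1e-15)

    idx = torch.randint(0, 2**19, (4096,))                  # indices hit by one batch of rays
    loss = table(idx).square().mean()                       # dummy loss, just to get a backward pass
    loss.backward()
    print(table.weight.grad.is_sparse)                      # True: untouched rows carry no gradient
    opt.step()
    opt.zero_grad()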
44 changes: 44 additions & 0 deletions optimizer.py
@@ -0,0 +1,44 @@
+#coding:utf-8
+import os, sys
+import os.path as osp
+import numpy as np
+import torch
+from torch import nn
+from torch.optim import Optimizer
+from functools import reduce
+from torch.optim import AdamW
+
+class MultiOptimizer:
+    def __init__(self, optimizers={}):
+        self.optimizers = optimizers
+        self.keys = list(optimizers.keys())
+        self.param_groups = reduce(lambda x,y: x+y, [v.param_groups for v in self.optimizers.values()])
+
+    def state_dict(self):
+        state_dicts = [(key, self.optimizers[key].state_dict())\
+                       for key in self.keys]
+        return state_dicts
+
+    def load_state_dict(self, state_dict):
+        for key, val in state_dict:
+            try:
+                self.optimizers[key].load_state_dict(val)
+            except:
+                print("Unloaded %s" % key)
+
+    def step(self, key=None, scaler=None):
+        keys = [key] if key is not None else self.keys
+        _ = [self._step(key, scaler) for key in keys]
+
+    def _step(self, key, scaler=None):
+        if scaler is not None:
+            scaler.step(self.optimizers[key])
+            scaler.update()
+        else:
+            self.optimizers[key].step()
+
+    def zero_grad(self, key=None):
+        if key is not None:
+            self.optimizers[key].zero_grad()
+        else:
+            _ = [self.optimizers[key].zero_grad() for key in self.keys]
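
A usage sketch of the MultiOptimizer above, matching how run_nerf.py wires it up below; the embedding table and the tiny Linear layer are stand-ins invented for illustration:

    # Hypothetical usage of MultiOptimizer; "embedding" and "mlp" are made-up stand-ins
    # for the hash-encoding table (sparse gradients) and the NeRF MLP (dense gradients).
    import torch
    import torch.nn as nn
    from optimizer import MultiOptimizer

    embedding = nn.Embedding(2**19, 2, sparse=True)
    mlp = nn.Linear(2, 4)

    sparse_opt = torch.optim.SparseAdam(embedding.parameters(), lr=0.01, betas=(0.9, 0.99), eps=1e-15)
    dense_opt = torch.optim.Adam(mlp.parameters(), lr=0.01, betas=(0.9, 0.99), weight_decay=1e-6)
    optimizer = MultiOptimizer(optimizers={"sparse_opt": sparse_opt, "dense_opt": dense_opt})

    loss = mlp(embedding(torch.randint(0, 2**19, (8,)))).sum()
    optimizer.zero_grad()        # zeroes every wrapped optimizer
    loss.backward()
    optimizer.step()             # steps both; pass key="dense_opt" to step only one of them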
13 changes: 9 additions & 4 deletions run_nerf.py
@@ -15,6 +15,7 @@
 import matplotlib.pyplot as plt

 from run_nerf_helpers import *
+from optimizer import MultiOptimizer

 from load_llff import load_llff_data
 from load_deepvoxels import load_dv_data
@@ -234,10 +235,13 @@ def create_nerf(args):

     # Create optimizer
     if args.i_embed==1:
-        optimizer = torch.optim.Adam([
-                        {'params': grad_vars, 'weight_decay': 1e-6},
-                        {'params': embedding_params, 'eps': 1e-15}
-                    ], lr=args.lrate, betas=(0.9, 0.99))
+        sparse_opt = torch.optim.SparseAdam(embedding_params, lr=args.lrate, betas=(0.9, 0.99), eps=1e-15)
+        dense_opt = torch.optim.Adam(grad_vars, lr=args.lrate, betas=(0.9, 0.99), weight_decay=1e-6)
+        optimizer = MultiOptimizer(optimizers={"sparse_opt": sparse_opt, "dense_opt": dense_opt})
+        # optimizer = torch.optim.Adam([
+        #                 {'params': grad_vars, 'weight_decay': 1e-6},
+        #                 {'params': embedding_params, 'eps': 1e-15}
+        #             ], lr=args.lrate, betas=(0.9, 0.99))
     else:
         optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))

@@ -682,6 +686,7 @@ def train():
         args.expname += "_posVIEW"
     args.expname += "_fine"+str(args.finest_res) + "_log2T"+str(args.log2_hashmap_size)
     args.expname += "_lr"+str(args.lrate) + "_decay"+str(args.lrate_decay)
+    args.expname += "_sparseopt"
     #args.expname += datetime.now().strftime('_%H_%M_%d_%m_%Y')
     expname = args.expname

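
One detail worth noting about the optimizer swap in create_nerf above: MultiOptimizer concatenates the param_groups of both wrapped optimizers, so a learning-rate schedule that mutates optimizer.param_groups keeps applying to the sparse and the dense optimizer alike. A sketch of such an update, assuming the usual NeRF-pytorch style exponential decay; the function name and default values are illustrative, not from this commit:

    # Illustrative helper (not in this commit): decay the lr of every wrapped
    # optimizer through MultiOptimizer's concatenated param_groups, assuming the
    # common NeRF-pytorch schedule lrate * 0.1 ** (step / (lrate_decay * 1000)).
    def decay_learning_rate(optimizer, global_step, lrate=0.01, lrate_decay=10, decay_rate=0.1):
        new_lrate = lrate * (decay_rate ** (global_step / (lrate_decay * 1000)))
        for param_group in optimizer.param_groups:   # covers sparse_opt and dense_opt alike
            param_group['lr'] = new_lrate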
4 changes: 1 addition & 3 deletions scripts/make_gif.py
@@ -9,9 +9,7 @@
 image_idx = "000"

 paths = {
-    "Vanilla Slow": "../logs/blender_chair_posXYZ_posVIEW_fine1024_log2T19_lr0.0005_decay500", \
-    "Hashed Superfast": "../logs/blender_chair_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10", \
-    "Hashed Fast": "../logs/blender_chair_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay100"}
+    "Hashed": "../logs/blender_hotdog_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10"}

 for path_name, log_path in paths.items():
     folders = [name for name in os.listdir(log_path) if name.startswith("renderonly_path_")]
4 changes: 2 additions & 2 deletions scripts/run_all_checkpoints.sh
@@ -1,3 +1,3 @@
-for i in logs/blender_chair_posXYZ_posVIEW_fine1024_log2T19_lr0.0005_decay500/*.tar; do
-CUDA_VISIBLE_DEVICES=2 python run_nerf.py --config configs/chair.txt --finest_res 1024 --i_embed 0 --i_embed_views 0 --render_only --ft_path $i
+for i in logs/blender_hotdog_hashXYZ_sphereVIEW_fine1024_log2T19_lr0.01_decay10/*.tar; do
+CUDA_VISIBLE_DEVICES=3 python run_nerf.py --config configs/hotdog.txt --finest_res 1024 --lr 0.01 --lr_decay 10 --render_only --ft_path $i
 done
