-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathinerf_helpers.py
32 lines (25 loc) · 1.13 KB
/
inerf_helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
import torch.nn as nn
def vec2ss_matrix(vector):  # vector to skewsym. matrix
    """Return the 3x3 skew-symmetric (cross-product) matrix of a 3-vector.

    For ``vector = (a, b, c)`` the result is::

        [[ 0, -c,  b],
         [ c,  0, -a],
         [-b,  a,  0]]

    so that ``vec2ss_matrix(v) @ u`` equals the cross product ``v x u``.

    Args:
        vector: Indexable with at least three numeric entries at positions
            0, 1 and 2 (a tensor or a plain sequence of numbers).

    Returns:
        A ``(3, 3)`` tensor. When ``vector`` is a tensor the result shares
        its dtype and device; otherwise it uses torch's current defaults.
    """
    if isinstance(vector, torch.Tensor):
        # Allocate next to the input so a float64 or non-default-device
        # vector is not silently cast or moved to the global defaults.
        ss_matrix = vector.new_zeros((3, 3))
    else:
        ss_matrix = torch.zeros((3, 3))
    # Element-wise writes keep autograd connected to `vector`.
    ss_matrix[0, 1] = -vector[2]
    ss_matrix[0, 2] = vector[1]
    ss_matrix[1, 0] = vector[2]
    ss_matrix[1, 2] = -vector[0]
    ss_matrix[2, 0] = -vector[1]
    ss_matrix[2, 1] = vector[0]
    return ss_matrix
class camera_transf(nn.Module):
    """Learnable 4x4 transform parameterised by ``w``, ``v`` and ``theta``.

    ``forward`` assembles a 4x4 matrix whose upper-left 3x3 block is
    ``I + sin(theta) [w] + (1 - cos(theta)) [w]^2`` and whose translation
    column is
    ``(I theta + (1 - cos(theta)) [w] + (theta - sin(theta)) [w]^2) v``,
    where ``[w]`` is the skew-symmetric matrix of ``w``. These are the
    closed-form expressions for the exponential of a screw axis ``(w, v)``
    scaled by ``theta``.

    NOTE(review): ``w`` is never normalised here, so the 3x3 block is an
    exact rotation matrix only while ``||w|| == 1``; confirm whether the
    caller relies on that.
    """

    def __init__(self):
        super().__init__()
        # All parameters start as tiny Gaussian noise (std 1e-6), so the
        # initial matrix is very close to the identity.
        self.w = nn.Parameter(torch.normal(0., 1e-6, size=(3,)))
        self.v = nn.Parameter(torch.normal(0., 1e-6, size=(3,)))
        self.theta = nn.Parameter(torch.normal(0., 1e-6, size=()))

    def forward(self, x):
        """Left-multiply ``x`` by the current 4x4 transform.

        Args:
            x: Tensor that ``torch.matmul`` can multiply with a ``(4, 4)``
                matrix on the left (presumably a 4x4 camera pose -- confirm
                against the caller). It should live on the same device and
                use the same dtype as this module's parameters.

        Returns:
            ``torch.matmul(T, x)`` where ``T`` is the assembled transform.
        """
        theta = self.theta
        sin_theta = torch.sin(theta)
        one_minus_cos = 1 - torch.cos(theta)

        # `.to(self.w)` pins the helper's output to the parameters' device
        # and dtype even if the helper allocated on the global defaults.
        w_skewsym = vec2ss_matrix(self.w).to(self.w)
        w_skewsym_sq = torch.matmul(w_skewsym, w_skewsym)
        # Build constants next to the parameters so the module keeps working
        # after `.to(device)` / `.double()` without touching global defaults.
        eye3 = torch.eye(3, dtype=self.w.dtype, device=self.w.device)

        rotation = eye3 + sin_theta * w_skewsym + one_minus_cos * w_skewsym_sq
        translation = torch.matmul(
            eye3 * theta
            + one_minus_cos * w_skewsym
            + (theta - sin_theta) * w_skewsym_sq,
            self.v,
        )

        exp_i = self.w.new_zeros((4, 4))
        exp_i[:3, :3] = rotation
        exp_i[:3, 3] = translation
        exp_i[3, 3] = 1.
        T_i = torch.matmul(exp_i, x)
        return T_i