Skip to content

Commit

Permalink
fix numerical issues with spherical flows
Browse files Browse the repository at this point in the history
  • Loading branch information
thoglu committed Jan 3, 2024
1 parent 9d45a26 commit e1f2e08
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 41 deletions.
25 changes: 17 additions & 8 deletions jammy_flows/layers/spheres/fvm_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,8 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
log_det=log_det+safe_ld_update

### we have to make the angles safe here...TODO: change to external transformation
ret=torch.where(ret<=-1.0, -1.0+1e-7, ret)
ret=torch.where(ret>=1.0, 1.0-1e-7, ret)
ret=torch.where(ret<-1.0, -1.0, ret)
ret=torch.where(ret>1.0, 1.0, ret)

angle=x[:,1:]

Expand Down Expand Up @@ -307,14 +307,19 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
angle=torch.masked_scatter(input=angle, mask=contained_mask[:,None], source=angle_contained)

log_det=torch.masked_scatter(input=log_det, mask=contained_mask, source=log_det_contained)


ret=torch.where(ret<-1.0, -1.0, ret)
ret=torch.where(ret>1.0, 1.0, ret)

## go back to angle in a safe way
ret=sphere_base.return_safe_angle_within_pi(torch.acos(ret))
rev_upd=torch.log(torch.sin(ret))[:,0]
ret=torch.acos(ret)
rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(ret[:,0])))

log_det=log_det-rev_upd

ret=torch.cat([ret, angle], dim=1)


if(self.always_parametrize_in_embedding_space):

ret, log_det=self.spherical_to_eucl_embedding(ret, log_det)
Expand All @@ -324,7 +329,7 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):

[x,log_det]=inputs

if(self.always_parametrize_in_embedding_space):
x, log_det=self.eucl_to_spherical_embedding(x, log_det)

Expand Down Expand Up @@ -444,12 +449,16 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):
raise Exception("Require 32 or 64 bit float")

ret=torch.where(kappa_mask, prev_ret, ret)

ret=torch.where(ret>1.0, 1.0, ret)
ret=torch.where(ret<-1.0, -1.0, ret)

## go back to angle
ret=torch.acos(ret)
rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(ret)))[:,0]

rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(ret[:,0])))
log_det=log_det-rev_upd

# join angles again
ret=torch.cat([ret, angle], dim=1)

if(self.always_parametrize_in_embedding_space):
Expand Down
44 changes: 11 additions & 33 deletions jammy_flows/layers/spheres/sphere_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .. import layer_base
import itertools

def return_safe_angle_within_pi(x, safety_margin=1e-10):
def return_safe_angle_within_pi(x, safety_margin=1e-7):
"""
Restricts the angle to not hit 0 or pi exactly.
"""
Expand Down Expand Up @@ -189,8 +189,13 @@ def eucl_to_spherical_embedding(self, x, log_det):
log_det=log_det-torch.log(torch.sin(angles[-1])).sum(axis=-1)

# phi
new_angle=torch.acos(x[:,0:1]/torch.sum(x[:,:2]**2, dim=1, keepdims=True).sqrt())
acos_arg=x[:,0:1]/torch.sum(x[:,:2]**2, dim=1, keepdims=True).sqrt()
acos_arg=torch.where(acos_arg>1.0, 1.0, acos_arg)
acos_arg=torch.where(acos_arg<-1.0, -1.0, acos_arg)

new_angle=torch.acos(acos_arg)
new_angle=torch.where(x[:,1:2]<0, 2*numpy.pi-new_angle, new_angle)

angles.append(new_angle)

return torch.cat(angles, dim=1), log_det
Expand Down Expand Up @@ -385,7 +390,6 @@ def sphere_to_plane(self, x, log_det, sf_extra=None):

x[:,0:1]=torch.exp(lnr)#

#print("log_det sphere_to_plane ", (-x[:,0:1]).sum(axis=-1))

else:

Expand Down Expand Up @@ -421,7 +425,7 @@ def plane_to_sphere(self, x, log_det):

x, log_det, keep_sign=self.inplane_euclidean_to_spherical(x, log_det)
sf_extra=None
#print("log det initial", log_det)

## first coordinate is now radial coordinate, other coordinates are angles
if(self.dimension==1):

Expand Down Expand Up @@ -463,8 +467,6 @@ def plane_to_sphere(self, x, log_det):
sf_extra=sfcyl
x[:,0:1]=lncyl

#print("log_det planetosphere ", (lncyl).sum(axis=-1))

else:

## pi-angle leads to a flipped stereographic projection
Expand Down Expand Up @@ -502,7 +504,7 @@ def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True)
[x, log_det] = inputs

## (1) apply inverse householder rotation if desired

if(self.add_rotation):


Expand All @@ -522,19 +524,6 @@ def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True)
if(self.always_parametrize_in_embedding_space==False):
x, log_det=self.eucl_to_spherical_embedding(x, log_det)


## correction required on 2-sphere
"""
if(self.dimension==2):
#print("inv 1) x ", x[:,0:1])
safe_angles=self.return_safe_angle_within_pi(x[:,0:1])
x=torch.cat( [safe_angles, x[:,1:]], dim=1)
#log_det+=-torch.log(torch.sin(x[:,0]))
"""
#print("inv 1) ld ", -torch.log(torch.sin(x[:,0])))
## (2) apply sphere intrinsic inverse flow function
## in 1 d case, _inv_flow_mapping should take as input values between 0 and 2pi, and outputs values between -pi and pi for easier further processing

sf_extra=None
if(extra_inputs is None):
inv_flow_results = self._inv_flow_mapping([x, log_det])
Expand All @@ -556,8 +545,7 @@ def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True)

#sys.exit(-1)
x, log_det=self.sphere_to_plane(x, log_det, sf_extra=sf_extra)



return x, log_det

## flow mapping (sampling pass)
Expand All @@ -579,16 +567,6 @@ def flow_mapping(self,inputs, extra_inputs=None):
else:
x,log_det = self._flow_mapping([x, log_det], extra_inputs=extra_inputs[:, self.num_householder_params:], sf_extra=sf_extra)

## safety check on 2-sphere
"""
if(self.dimension==2):
safe_angles=self.return_safe_angle_within_pi(x[:,0:1])
x=torch.cat( [safe_angles, x[:,1:]], dim=1)
#log_det+=torch.log(torch.sin(x[:,0]))
"""

## (3) apply householder rotation in embedding space if wanted
#extra_input_counter=0
Expand All @@ -606,6 +584,7 @@ def flow_mapping(self,inputs, extra_inputs=None):
## use broadcasting
x = torch.einsum("...ij, ...j -> ...i", mat, x)
if(self.always_parametrize_in_embedding_space==False):

x, log_det=self.eucl_to_spherical_embedding(x, log_det)

return x,log_det
Expand Down Expand Up @@ -661,7 +640,6 @@ def return_problematic_pars_between_hh_and_intrinsic(self, x, extra_inputs=None,

problematic_points=x[mask]

#print("problematic points", problematic_points)
return problematic_points

def obtain_layer_param_structure(self, param_dict, extra_inputs=None, previous_x=None, extra_prefix=""):
Expand Down

0 comments on commit e1f2e08

Please sign in to comment.