Skip to content

Commit

Permalink
better numerical stability and a new option
Browse files Browse the repository at this point in the history
  • Loading branch information
thoglu committed Jan 23, 2024
1 parent b417bcd commit 601209c
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 59 deletions.
2 changes: 1 addition & 1 deletion jammy_flows/flow_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
opts_dict["f"]["kwargs"]["vertical_fix_boundary_derivative"] = (1, lambda x: [0,1])
opts_dict["f"]["kwargs"]["min_kappa"] = (1e-10, lambda x: x>0)
opts_dict["f"]["kwargs"]["kappa_prediction"] = ("direct_log_real_bounded", ["direct_log_real_bounded", "log_bounded"])

opts_dict["f"]["kwargs"]["add_extra_rotation_inbetween"] = (0, [0,1])

"""
Interval flows
Expand Down
89 changes: 68 additions & 21 deletions jammy_flows/layers/spheres/fvm_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,6 @@
import torch.autograd


"""
Implementation of the Fisher-von-Mises distribution as a normalizing flow.
"""

class fisher_von_mises_2d(sphere_base.sphere_base):

def __init__(self,
Expand All @@ -50,11 +46,12 @@ def __init__(self,
vertical_restrict_max_min_width_height_ratio=-1.0,
vertical_fix_boundary_derivative=1,
min_kappa=1e-10,
kappa_prediction="direct_log_real_bounded"):
kappa_prediction="direct_log_real_bounded",
add_extra_rotation_inbetween=0):
"""
Symbol: "f"
Based off of https://arxiv.org/abs/2002.02428, with additional FvM scalings and various options to play with.
Based off of https://arxiv.org/abs/2002.02428.
Parameters:
Expand Down Expand Up @@ -178,6 +175,7 @@ def __init__(self,
self.correlated_flow_params = nn.Parameter(torch.randn(1, self.total_num_correlated_params))


self.add_extra_rotation_inbetween=add_extra_rotation_inbetween

def _inv_flow_mapping(self, inputs, extra_inputs=None):

Expand Down Expand Up @@ -220,23 +218,23 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):

## go to cylinder from angle
prev_ret=torch.cos(x[:,:1])

fw_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(x[:,0])))

log_det=log_det+fw_upd

## intermediate [-1,1]->[-1,1] transformation
safe_part=2*kappa
smaller_mask=kappa[:,0]<100

## safe_part only involves kappa
safe_part=torch.masked_scatter(input=safe_part, mask=smaller_mask[:,None], source=torch.log(torch.exp(2*kappa[smaller_mask])-1.0))

#prev_ret_inverse=self.z_scaling_factor*prev_ret
safe_ld_update=(torch.log(2*kappa)+kappa*(self.z_scaling_factor*prev_ret+1)-safe_part)[:,0]

## 1 + 1 - 2*k -2*(1+k(x-1)) / (-2k) = 2 - 2k -2 -2k(x-1) / -2k = 1 + 1(x-1)
ret= self.z_scaling_factor*((1.0+torch.exp(-2*kappa)-2*torch.exp(kappa*(self.z_scaling_factor*prev_ret-1)))/(-1+torch.exp(-2*kappa)))

switched=self.z_scaling_factor*prev_ret

ret= self.z_scaling_factor*((1.0+torch.exp(-2*kappa)-2*torch.exp(kappa*(self.z_scaling_factor*prev_ret-1)))/(-1+torch.exp(-2*kappa)))

#approx_result=ret # nothing happens for k->0
if(x.dtype==torch.float32):
kappa_mask=kappa<1e-4
Expand All @@ -251,10 +249,32 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):
log_det=log_det+safe_ld_update

### we have to make the angles safe here...TODO: change to external transformation
ret=torch.where(ret<-1.0, -1.0, ret)
ret=torch.where(ret>1.0, 1.0, ret)

ret=sphere_base.return_safe_costheta(ret)
angle=x[:,1:]


if(self.add_extra_rotation_inbetween):

ret=torch.acos(ret)

rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(ret[:,0])))
log_det=log_det-rev_upd

comb=torch.cat([ret, angle],dim=1)

comb, log_det=self.spherical_to_eucl_embedding(comb, log_det)

inbetween_matrix=torch.Tensor([[0.0,0.0,1.0],[0.0,1.0,0.0],[-1.0,0.0,0.0]]).to(ret).type_as(ret).unsqueeze(0)

comb=torch.einsum("...ij,...j->...i", inbetween_matrix.permute(0,2,1), comb)

comb, log_det=self.eucl_to_spherical_embedding(comb, log_det)

ret=torch.cos(comb[:,:1])
fw_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(comb[:,0])))
log_det=log_det+fw_upd

angle=comb[:,1:]

if(self.boundary_cos_theta_identity_region==0.0):

Expand Down Expand Up @@ -308,8 +328,7 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):

log_det=torch.masked_scatter(input=log_det, mask=contained_mask, source=log_det_contained)

ret=torch.where(ret<-1.0, -1.0, ret)
ret=torch.where(ret>1.0, 1.0, ret)
ret=sphere_base.return_safe_costheta(ret)

## go back to angle in a safe way
ret=torch.acos(ret)
Expand All @@ -319,11 +338,11 @@ def _inv_flow_mapping(self, inputs, extra_inputs=None):

ret=torch.cat([ret, angle], dim=1)


if(self.always_parametrize_in_embedding_space):

ret, log_det=self.spherical_to_eucl_embedding(ret, log_det)


return ret, log_det, sf_extra

def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):
Expand Down Expand Up @@ -430,6 +449,32 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):
log_det=torch.masked_scatter(input=log_det, mask=contained_mask, source=log_det_contained)


if(self.add_extra_rotation_inbetween):

## go back to angle

prev_ret=torch.acos(prev_ret)

rev_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(prev_ret[:,0])))
log_det=log_det-rev_upd

comb=torch.cat([prev_ret, angle],dim=1)

comb, log_det=self.spherical_to_eucl_embedding(comb, log_det)

inbetween_matrix=torch.Tensor([[0.0,0.0,1.0],[0.0,1.0,0.0],[-1.0,0.0,0.0]]).to(prev_ret).type_as(prev_ret).unsqueeze(0)

comb=torch.einsum("...ij,...j->...i", inbetween_matrix, comb)

comb, log_det=self.eucl_to_spherical_embedding(comb, log_det)

prev_ret=torch.cos(comb[:,:1])
fw_upd=torch.log(torch.sin(sphere_base.return_safe_angle_within_pi(comb[:,0])))
log_det=log_det+fw_upd

angle=comb[:,1:]


## kappa->0

## 0.5+0.5x + (0.5-0.5x)*(1-2k) = 1 -k+kx = (1+k(x-1))^(1/k)
Expand All @@ -439,8 +484,10 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):
#prev_ret_inverse=self.z_scaling_factor*prev_ret

log_det=log_det-torch.log(kappa*self.z_scaling_factor*prev_ret+kappa/torch.tanh(kappa))[:,0]


ret=self.z_scaling_factor*(1.0+(1.0/kappa)*torch.log( 0.5*(1.0+self.z_scaling_factor*prev_ret) + (0.5-0.5*self.z_scaling_factor*prev_ret)*torch.exp(-2.0*kappa) ))

if(x.dtype==torch.float32):
kappa_mask=kappa<1e-4
elif(x.dtype==torch.float64):
Expand All @@ -449,9 +496,9 @@ def _flow_mapping(self, inputs, extra_inputs=None, sf_extra=None):
raise Exception("Require 32 or 64 bit float")

ret=torch.where(kappa_mask, prev_ret, ret)
ret=torch.where(ret>1.0, 1.0, ret)
ret=torch.where(ret<-1.0, -1.0, ret)

ret=sphere_base.return_safe_costheta(ret)

## go back to angle
ret=torch.acos(ret)

Expand Down
89 changes: 52 additions & 37 deletions jammy_flows/layers/spheres/sphere_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,25 @@ def return_safe_angle_within_pi(x, safety_margin=1e-7):

return ret

def return_safe_costheta(x, safety_margin=None):
"""
Restricts the angle to not hit 0 or pi exactly.
"""
used_safety_margin=safety_margin
if(safety_margin is None):
if(x.dtype==torch.float32):
used_safety_margin=1e-7
elif(x.dtype==torch.float64):
used_safety_margin=1e-10

small_mask=x<(-1.0+used_safety_margin)
large_mask=x>(1.0-used_safety_margin)

ret=torch.where(small_mask, -1.0+used_safety_margin, x)
ret=torch.where(large_mask, 1.0-used_safety_margin, ret)

return ret

class sphere_base(layer_base.layer_base):

def __init__(self,
Expand Down Expand Up @@ -266,10 +285,11 @@ def inplane_euclidean_to_spherical(self, x, log_det):
transformed_coords=[]
keep_sign=None

radius=(x**2).sum(dim=1, keepdims=True).sqrt()

for ind in range(self.dimension):
if(ind==0):
radius=(x**2).sum(dim=1, keepdims=True).sqrt()


## we dont want radii of exactly 0 normally
## but we can allow it because we drop the usual log_det(radius) below aswell
#radius[radius==0]=1e-10
Expand All @@ -284,23 +304,22 @@ def inplane_euclidean_to_spherical(self, x, log_det):
## standard jacobian
# log_det+=-torch.log(radius[:,0])*(self.dimension-1)

else:


mod_ind=ind-1
elif(ind==1):
nominator=x[:,:1]

new_angle=torch.acos(x[:,mod_ind:mod_ind+1]/torch.sum(x[:,mod_ind:]**2, dim=1, keepdims=True).sqrt())
new_argument=torch.where(radius==0, 1.0, nominator/radius)
new_angle=torch.acos(new_argument)

if(ind==self.dimension-1):
## last one, check sign flip
mask_smaller=(x[:,ind:ind+1]<0)#.double()
new_angle=torch.where(mask_smaller, 2*numpy.pi-new_angle, new_angle)
#new_angle=mask_smaller*(2*numpy.pi-new_angle)+(1.0-mask_smaller)*new_angle
else:
raise NotImplementedError("D>2 not implemented for D-spheres currently")
log_det=log_det+torch.log(torch.sin(new_angle[:,0]))*(self.dimension-1-ind)
mask_smaller=(x[:,1:2]<0)
new_angle=torch.where(mask_smaller, 2*numpy.pi-new_angle, new_angle)

transformed_coords.append(new_angle)
else:
# D>2 spheres not implemented currently
#mod_ind=ind-1
#new_angle=torch.acos(x[:,mod_ind:mod_ind+1]/torch.sum(x[:,mod_ind:]**2, dim=1, keepdims=True).sqrt())
#log_det=log_det+torch.log(torch.sin(new_angle[:,0]))*(self.dimension-1-ind)
raise Exception("Higher order spheres not supported right now!")


return torch.cat(transformed_coords, dim=1), log_det, keep_sign
Expand Down Expand Up @@ -393,24 +412,21 @@ def sphere_to_plane(self, x, log_det, sf_extra=None):

else:

cos_x=torch.cos(x[:,0:1])

good_cos_x=(cos_x!=1.0) & (cos_x!=-1.0)

cos_x=torch.where(cos_x==1.0, cos_x-1e-6, cos_x)
cos_x=torch.where(cos_x==-1.0, cos_x+1e-6, cos_x)
safe_theta=return_safe_angle_within_pi(x[:,0:1])

cos_x=torch.cos(safe_theta)
## TODO: check if we really need 1e-6 here, but it seems to be so
cos_x=return_safe_costheta(cos_x, safety_margin=1e-6)

#cos_x=(cos_x==1.0)*(cos_x-1e-5)+(cos_x==-1.0)*(cos_x+1e-5)+good_cos_x*cos_x
r_g=torch.sqrt(-torch.log( (1.0-cos_x)/2.0 )*2.0)

#inner=1.0-2.0*torch.exp(-((r_g)**2)/2.0)

## the normal log_det .. we use another factor that drops the r term and is in concordance with *inplane_spherical_to_euclidean* definition
## we also drop the sin(theta) factor, to be in accord with the spherical measure
### FULL TERM:
### log_det+=-torch.log(r_g[:,0])-torch.log(1.0-cos_x[:,0])+torch.log(torch.sin(x[:,0]))
log_det=log_det-torch.log(1.0-cos_x[:,0])+torch.log(torch.sin(x[:,0]))


log_det=log_det-torch.log(1.0-cos_x[:,0])+torch.log(torch.sin(safe_theta[:,0]))

x=torch.cat([r_g, x[:,1:2]],dim=1)

else:
Expand Down Expand Up @@ -471,7 +487,7 @@ def plane_to_sphere(self, x, log_det):

## pi-angle leads to a flipped stereographic projection
new_theta=torch.acos(1.0-2.0*torch.exp(-((x[:,0:1])**2)/2.0))

new_theta=return_safe_angle_within_pi(new_theta)
#r_g=x[:,0]
#inner=1.0-2.0*torch.exp(-((r_g)**2)/2.0)

Expand Down Expand Up @@ -502,15 +518,16 @@ def plane_to_sphere(self, x, log_det):
def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True):

[x, log_det] = inputs

## (1) apply inverse householder rotation if desired



if(self.add_rotation):


if(self.always_parametrize_in_embedding_space==False):
x, log_det=self.spherical_to_eucl_embedding(x, log_det)

## householder dimension is one higher than sphere dimension (we rotate in embedding space)


Expand All @@ -531,7 +548,7 @@ def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True)
inv_flow_results = self._inv_flow_mapping([x, log_det], extra_inputs=extra_inputs[:, self.num_householder_params:])

x, log_det = inv_flow_results[:2]

if(self.higher_order_cylinder_parametrization):
sf_extra=inv_flow_results[2]

Expand All @@ -542,10 +559,9 @@ def inv_flow_mapping(self, inputs, extra_inputs=None, include_area_element=True)

if(self.always_parametrize_in_embedding_space and sf_extra is None):
x, log_det=self.eucl_to_spherical_embedding(x, log_det)

#sys.exit(-1)

x, log_det=self.sphere_to_plane(x, log_det, sf_extra=sf_extra)

return x, log_det

## flow mapping (sampling pass)
Expand Down Expand Up @@ -576,7 +592,7 @@ def flow_mapping(self,inputs, extra_inputs=None):
x, log_det=self.spherical_to_eucl_embedding(x, log_det)

#xy=torch.cat((x, y), dim=1)

mat=self.compute_rotation_matrix(x,extra_inputs=extra_inputs, mode=self.rotation_mode, device=x.device)

#x = torch.bmm(mat, x.unsqueeze(-1)).squeeze(-1) # uncomment
Expand Down Expand Up @@ -636,8 +652,7 @@ def return_problematic_pars_between_hh_and_intrinsic(self, x, extra_inputs=None,
new_pts,_=self.eucl_to_spherical_embedding(eucl,0.0)

mask=(new_pts[:,0]<flag_pole_distance) | (new_pts[:,0] >(numpy.pi-flag_pole_distance))
#mask=(new_pts[:,1]<flag_pole_distance) | (new_pts[:,1] >(2*numpy.pi-flag_pole_distance))


problematic_points=x[mask]

return problematic_points
Expand Down

0 comments on commit 601209c

Please sign in to comment.