Skip to content

Commit

Permalink
Merge pull request #25 from johannbrehmer/master
Browse files — browse the repository at this point in the history
Fix splines crashing with all-tail inputs
  • Loading branch information
arturbekasov authored Nov 9, 2020
2 parents 761b790 + 3ffd42b commit 4c5bbfe
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 54 deletions.
33 changes: 17 additions & 16 deletions nflows/transforms/splines/cubic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,23 @@ def unconstrained_cubic_spline(
else:
raise RuntimeError("{} tails are not implemented.".format(tails))

outputs[inside_interval_mask], logabsdet[inside_interval_mask] = cubic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnorm_derivatives_left=unnorm_derivatives_left[inside_interval_mask, :],
unnorm_derivatives_right=unnorm_derivatives_right[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
eps=eps,
quadratic_threshold=quadratic_threshold,
)
if torch.any(inside_interval_mask):
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = cubic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnorm_derivatives_left=unnorm_derivatives_left[inside_interval_mask, :],
unnorm_derivatives_right=unnorm_derivatives_right[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
eps=eps,
quadratic_threshold=quadratic_threshold,
)

return outputs, logabsdet

Expand Down
19 changes: 10 additions & 9 deletions nflows/transforms/splines/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ def unconstrained_linear_spline(
else:
raise RuntimeError("{} tails are not implemented.".format(tails))

outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline(
inputs=inputs[inside_interval_mask],
unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
)
if torch.any(inside_interval_mask):
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline(
inputs=inputs[inside_interval_mask],
unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
)

return outputs, logabsdet

Expand Down
25 changes: 13 additions & 12 deletions nflows/transforms/splines/quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,19 @@ def unconstrained_quadratic_spline(
else:
raise RuntimeError("{} tails are not implemented.".format(tails))

outputs[inside_interval_mask], logabsdet[inside_interval_mask] = quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
)
if torch.any(inside_interval_mask):
outputs[inside_interval_mask], logabsdet[inside_interval_mask] = quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
)

return outputs, logabsdet

Expand Down
35 changes: 18 additions & 17 deletions nflows/transforms/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,24 @@ def unconstrained_rational_quadratic_spline(
else:
raise RuntimeError("{} tails are not implemented.".format(tails))

(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)
if torch.any(inside_interval_mask):
(
outputs[inside_interval_mask],
logabsdet[inside_interval_mask],
) = rational_quadratic_spline(
inputs=inputs[inside_interval_mask],
unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
inverse=inverse,
left=-tail_bound,
right=tail_bound,
bottom=-tail_bound,
top=tail_bound,
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
)

return outputs, logabsdet

Expand Down
29 changes: 29 additions & 0 deletions tests/transforms/splines/cubic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,32 @@ def call_spline_fn(inputs, inverse=False):
self.eps = 1e-4
self.assertEqual(inputs, inputs_inv)
self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))

def test_forward_inverse_are_consistent_in_tails(self):
    """Cubic spline: forward then inverse is the identity when *every* input is in the tails.

    Regression check for the all-tail crash: the masked spline call must be
    skipped gracefully when no input falls inside [-tail_bound, tail_bound].
    """
    num_bins = 10
    shape = [2, 3, 4]
    tail_bound = 1.0

    # Random spline parameters, shared by the forward and inverse passes.
    spline_params = dict(
        unnormalized_widths=torch.randn(*shape, num_bins),
        unnormalized_heights=torch.randn(*shape, num_bins),
        unnorm_derivatives_left=torch.randn(*shape, 1),
        unnorm_derivatives_right=torch.randn(*shape, 1),
        tail_bound=tail_bound,
    )

    # Push every input strictly outside [-tail_bound, tail_bound].
    signs = torch.sign(torch.randn(*shape))
    inputs = signs * (tail_bound + torch.rand(*shape))

    outputs, logabsdet = splines.unconstrained_cubic_spline(
        inputs=inputs, inverse=False, **spline_params
    )
    recovered, logabsdet_inv = splines.unconstrained_cubic_spline(
        inputs=outputs, inverse=True, **spline_params
    )

    self.eps = 1e-4
    self.assertEqual(inputs, recovered)
    # Forward and inverse log-determinants must cancel exactly.
    self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))
20 changes: 20 additions & 0 deletions tests/transforms/splines/linear_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,23 @@ def call_spline_fn(inputs, inverse=False):
self.eps = 1e-4
self.assertEqual(inputs, inputs_inv)
self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))

def test_forward_inverse_are_consistent_in_tails(self):
    """Linear spline: forward then inverse is the identity when *every* input is in the tails.

    Regression check for the all-tail crash: the masked spline call must be
    skipped gracefully when no input falls inside [-tail_bound, tail_bound].
    """
    num_bins = 10
    shape = [2, 3, 4]
    tail_bound = 1.0

    pdf = torch.randn(*shape, num_bins)

    # Push every input strictly outside [-tail_bound, tail_bound].
    signs = torch.sign(torch.randn(*shape))
    inputs = signs * (tail_bound + torch.rand(*shape))

    outputs, logabsdet = splines.unconstrained_linear_spline(
        inputs=inputs, unnormalized_pdf=pdf, inverse=False, tail_bound=tail_bound
    )
    recovered, logabsdet_inv = splines.unconstrained_linear_spline(
        inputs=outputs, unnormalized_pdf=pdf, inverse=True, tail_bound=tail_bound
    )

    self.eps = 1e-4
    self.assertEqual(inputs, recovered)
    # Forward and inverse log-determinants must cancel exactly.
    self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))
25 changes: 25 additions & 0 deletions tests/transforms/splines/quadratic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,28 @@ def call_spline_fn(inputs, inverse=False):
self.eps = 1e-4
self.assertEqual(inputs, inputs_inv)
self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))

def test_forward_inverse_are_consistent_in_tails(self):
    """Quadratic spline: forward then inverse is the identity when *every* input is in the tails.

    Regression check for the all-tail crash: the masked spline call must be
    skipped gracefully when no input falls inside [-tail_bound, tail_bound].
    """
    num_bins = 10
    shape = [2, 3, 4]
    tail_bound = 1.0

    # Random spline parameters, shared by the forward and inverse passes.
    # NOTE: the quadratic spline takes num_bins - 1 height parameters.
    spline_params = dict(
        unnormalized_widths=torch.randn(*shape, num_bins),
        unnormalized_heights=torch.randn(*shape, num_bins - 1),
        tail_bound=tail_bound,
    )

    # Push every input strictly outside [-tail_bound, tail_bound].
    signs = torch.sign(torch.randn(*shape))
    inputs = signs * (tail_bound + torch.rand(*shape))

    outputs, logabsdet = splines.unconstrained_quadratic_spline(
        inputs=inputs, inverse=False, **spline_params
    )
    recovered, logabsdet_inv = splines.unconstrained_quadratic_spline(
        inputs=outputs, inverse=True, **spline_params
    )

    self.eps = 1e-4
    self.assertEqual(inputs, recovered)
    # Forward and inverse log-determinants must cancel exactly.
    self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))
26 changes: 26 additions & 0 deletions tests/transforms/splines/rational_quadratic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,29 @@ def call_spline_fn(inputs, inverse=False):
self.eps = 1e-4
self.assertEqual(inputs, inputs_inv)
self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))

def test_forward_inverse_are_consistent_in_tails(self):
    """Rational-quadratic spline: forward then inverse is the identity when *every* input is in the tails.

    Regression check for the all-tail crash: the masked spline call must be
    skipped gracefully when no input falls inside [-tail_bound, tail_bound].
    """
    num_bins = 10
    shape = [2, 3, 4]
    tail_bound = 1.0

    unnormalized_widths = torch.randn(*shape, num_bins)
    unnormalized_heights = torch.randn(*shape, num_bins)
    unnormalized_derivatives = torch.randn(*shape, num_bins + 1)

    def call_spline_fn(inputs, inverse=False):
        return splines.unconstrained_rational_quadratic_spline(
            inputs=inputs,
            unnormalized_widths=unnormalized_widths,
            unnormalized_heights=unnormalized_heights,
            unnormalized_derivatives=unnormalized_derivatives,
            inverse=inverse,
            # Fix: pass tail_bound explicitly, as the other spline tests do.
            # Previously the local tail_bound was unused here, so the test only
            # held because the library default happens to equal 1.0.
            tail_bound=tail_bound,
        )

    inputs = torch.sign(torch.randn(*shape)) * (tail_bound + torch.rand(*shape))  # Now *all* inputs are outside [-tail_bound, tail_bound].
    outputs, logabsdet = call_spline_fn(inputs, inverse=False)
    inputs_inv, logabsdet_inv = call_spline_fn(outputs, inverse=True)

    self.eps = 1e-4
    self.assertEqual(inputs, inputs_inv)
    # Forward and inverse log-determinants must cancel exactly.
    self.assertEqual(logabsdet + logabsdet_inv, torch.zeros_like(logabsdet))

0 comments on commit 4c5bbfe

Please sign in to comment.