
Commit

implement GPD from wikipedia definition with 3 params
kashif committed Sep 13, 2024
1 parent 7668ce1 commit e2c2962
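
The commit message refers to the Wikipedia definition of the Generalized Pareto distribution (GPD), which uses a location mu, a scale sigma, and a shape xi; these map onto the loc, scale, and concentration parameters in the diff below. As a summary of that definition (an editorial reference, not text from the commit), in LaTeX:

F_{(\mu,\sigma,\xi)}(x) = 1 - \Bigl(1 + \frac{\xi (x - \mu)}{\sigma}\Bigr)^{-1/\xi},
\qquad
f_{(\mu,\sigma,\xi)}(x) = \frac{1}{\sigma}\Bigl(1 + \frac{\xi (x - \mu)}{\sigma}\Bigr)^{-(1/\xi + 1)},

with the limits 1 - e^{-(x-\mu)/\sigma} and \frac{1}{\sigma} e^{-(x-\mu)/\sigma} as \xi \to 0. The support is x \ge \mu for \xi \ge 0 and \mu \le x \le \mu - \sigma/\xi for \xi < 0, which is what the new support property in the diff encodes.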
Showing 2 changed files with 262 additions and 118 deletions.
256 changes: 140 additions & 116 deletions src/gluonts/torch/distributions/generalized_pareto.py
@@ -11,11 +11,13 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-from numbers import Number
-from typing import Dict, Optional, Tuple, cast
+import math
+from numbers import Number, Real
+from typing import Dict, Tuple, cast
 
-import numpy as np
 import torch
+from torch import nan, inf
+import torch.nn.functional as F
 from torch.distributions import Distribution, constraints
 from torch.distributions.utils import broadcast_all

@@ -26,118 +28,154 @@
 
 
 class GeneralizedPareto(Distribution):
     r"""
-    Generalised Pareto distribution.
-
-    Parameters
-    ----------
-    xi
-        Tensor containing the xi (heaviness) shape parameters. The tensor is
-        of shape (*batch_shape, 1)
-    beta
-        Tensor containing the beta scale parameters. The tensor is of
-        shape (*batch_shape, 1)
+    Creates a Generalized Pareto distribution parameterized by :attr:`loc`,
+    :attr:`scale`, and :attr:`concentration`.
+
+    The Generalized Pareto distribution is a family of continuous probability
+    distributions on the real line. Special cases include Exponential (when
+    :attr:`loc` = 0, :attr:`concentration` = 0), Pareto (when
+    :attr:`concentration` > 0, :attr:`loc` = :attr:`scale` /
+    :attr:`concentration`), and Uniform (when :attr:`concentration` = -1).
+
+    This distribution is often used to model the tails of other
+    distributions. This implementation is based on the implementation in
+    TensorFlow Probability.
+
+    Example::
+
+        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
+        >>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4]))
+        >>> m.sample()  # sample with loc=0.1, scale=2.0, concentration=0.4
+        tensor([ 1.5623])
+
+    Args:
+        loc (float or Tensor): Location parameter of the distribution
+        scale (float or Tensor): Scale parameter of the distribution
+        concentration (float or Tensor): Concentration parameter of the
+            distribution
     """
 
     arg_constraints = {
-        "xi": constraints.positive,
-        "beta": constraints.positive,
+        "loc": constraints.real,
+        "scale": constraints.positive,
+        "concentration": constraints.real,
     }
-    support = constraints.positive
     has_rsample = False
 
-    def __init__(self, xi, beta, validate_args=None):
-        self.xi, self.beta = broadcast_all(
-            xi.squeeze(dim=-1), beta.squeeze(dim=-1)
-        )
-
-        setattr(self, "xi", xi)
-        setattr(self, "beta", beta)
-
-        super(GeneralizedPareto, self).__init__()
-
-        if isinstance(xi, Number) and isinstance(beta, Number):
-            batch_shape = torch.Size()
-        else:
-            batch_shape = self.xi.size()
-        super(GeneralizedPareto, self).__init__(
-            batch_shape, validate_args=validate_args
-        )
-
-        if (
-            self._validate_args
-            and not torch.lt(-self.beta, torch.zeros_like(self.beta)).all()
-        ):
-            raise ValueError("GenPareto is not defined when scale beta<=0")
+    def __init__(self, loc, scale, concentration, validate_args=None):
+        self.loc, self.scale, self.concentration = broadcast_all(
+            loc, scale, concentration
+        )
+        if (
+            isinstance(loc, Number)
+            and isinstance(scale, Number)
+            and isinstance(concentration, Number)
+        ):
+            batch_shape = torch.Size()
+        else:
+            batch_shape = self.loc.size()
+        super().__init__(batch_shape, validate_args=validate_args)
+
+    def expand(self, batch_shape, _instance=None):
+        new = self._get_checked_instance(GeneralizedPareto, _instance)
+        batch_shape = torch.Size(batch_shape)
+        new.loc = self.loc.expand(batch_shape)
+        new.scale = self.scale.expand(batch_shape)
+        new.concentration = self.concentration.expand(batch_shape)
+        super(GeneralizedPareto, new).__init__(
+            batch_shape, validate_args=False
+        )
+        new._validate_args = self._validate_args
+        return new
+
+    def sample(self, sample_shape=torch.Size()):
+        shape = self._extended_shape(sample_shape)
+        with torch.no_grad():
+            u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device)
+            return self.icdf(u)
+
+    def log_prob(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        z = self._z(value)
+        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
+        safe_conc = torch.where(
+            eq_zero, torch.ones_like(self.concentration), self.concentration
+        )
+        y = 1 / safe_conc + torch.ones_like(z)
+        where_nonzero = torch.where(y == 0, y, y * torch.log1p(safe_conc * z))
+        log_scale = (
+            math.log(self.scale)
+            if isinstance(self.scale, Real)
+            else self.scale.log()
+        )
+        return -log_scale - torch.where(eq_zero, z, where_nonzero)
+
+    def log_survival_function(self, value):
+        if self._validate_args:
+            self._validate_sample(value)
+        z = self._z(value)
+        eq_zero = torch.isclose(self.concentration, torch.tensor(0.0))
+        safe_conc = torch.where(
+            eq_zero, torch.ones_like(self.concentration), self.concentration
+        )
+        where_nonzero = -torch.log1p(safe_conc * z) / safe_conc
+        return torch.where(eq_zero, -z, where_nonzero)
+
+    def log_cdf(self, value):
+        return torch.log1p(-torch.exp(self.log_survival_function(value)))
+
+    def cdf(self, value):
+        return torch.exp(self.log_cdf(value))
+
+    def icdf(self, value):
+        loc = self.loc
+        scale = self.scale
+        concentration = self.concentration
+        eq_zero = torch.isclose(concentration, torch.zeros_like(concentration))
+        safe_conc = torch.where(
+            eq_zero, torch.ones_like(concentration), concentration
+        )
+        logu = torch.log1p(-value)
+        where_nonzero = loc + scale / safe_conc * torch.expm1(
+            -safe_conc * logu
+        )
+        where_zero = loc - scale * logu
+        return torch.where(eq_zero, where_zero, where_nonzero)
+
+    def _z(self, x):
+        return (x - self.loc) / self.scale
 
     @property
     def mean(self):
-        """
-        Returns the mean of the distribution, of shape (*batch_shape,)
-        """
-        mu = torch.where(
-            self.xi < 1,
-            torch.div(self.beta, 1 - self.xi),
-            np.nan * torch.ones_like(self.xi),
-        )
-        return mu
+        concentration = self.concentration
+        valid = concentration < 1
+        safe_conc = torch.where(valid, concentration, 0.5)
+        result = self.loc + self.scale / (1 - safe_conc)
+        return torch.where(valid, result, torch.full_like(result, nan))
 
     @property
     def variance(self):
-        """
-        Returns the variance of the distribution, of shape (*batch_shape,)
-        """
-        xi, beta = self.xi, self.beta
-        var = torch.where(
-            xi < 1 / 2.0,
-            torch.div(beta**2, torch.mul((1 - xi) ** 2, (1 - 2 * xi))),
-            np.nan * torch.ones_like(xi),
-        )
-        return var
+        concentration = self.concentration
+        valid = concentration < 0.5
+        safe_conc = torch.where(valid, concentration, 0.25)
+        result = self.scale**2 / (
+            (1 - safe_conc) ** 2 * (1 - 2 * safe_conc)
+        ) + torch.zeros_like(self.loc)
+        return torch.where(valid, result, torch.full_like(result, nan))
+
+    def entropy(self):
+        ans = torch.log(self.scale) + self.concentration + 1
+        return torch.broadcast_to(ans, self._batch_shape)
 
     @property
     def stddev(self):
         return torch.sqrt(self.variance)
 
-    def log_prob(self, x):
-        """
-        Log probability for a tensor x of shape (*batch_shape)
-        """
-        # both xi and beta have shape (*batch_shape)
-        # and so do all the elements bellow
-
-        x = x.unsqueeze(dim=-1)
-
-        logp = -self.beta.log().double()
-        logp += torch.where(
-            self.xi == torch.zeros_like(self.xi),
-            -x / self.beta,
-            -(1 + 1.0 / (self.xi + 1e-6))
-            * torch.log(1 + self.xi * x / self.beta),
-        )
-        logp = torch.where(
-            x < torch.zeros_like(x),
-            (-np.inf * torch.ones_like(x)).double(),
-            logp,
-        )
-        return logp.squeeze(dim=-1)
-
-    def cdf(self, x):
-        """
-        cdf values for a tensor x of shape (*batch_shape)
-        """
-        x = x.unsqueeze(dim=-1)
-        x_shifted = torch.div(x, self.beta)
-        u = 1 - torch.pow(1 + self.xi * x_shifted, -torch.reciprocal(self.xi))
-        return u.squeeze(dim=-1)
-
-    def icdf(self, value):
-        """
-        icdf values for a tensor quantile values of shape (*batch_shape)
-        """
-        value = value.unsqueeze(dim=-1)
-        x_shifted = torch.div(torch.pow(1 - value, -self.xi) - 1, self.xi)
-        x = torch.mul(x_shifted, self.beta)
-        return x.squeeze(dim=-1)
+    @property
+    def mode(self):
+        return self.loc
+
+    @constraints.dependent_property(is_discrete=False, event_dim=0)
+    def support(self):
+        neg_conc = self.concentration < 0
+        upper = torch.where(
+            neg_conc,
+            self.loc - self.scale / self.concentration,
+            torch.full_like(self.loc, inf),
+        )
+        lower = self.loc
+        return constraints.interval(lower, upper)
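
To exercise the reworked three-parameter API end to end, here is a minimal usage sketch (not part of the commit; the import path is assumed from the file location above, and the parameter values are arbitrary):

import torch

from gluonts.torch.distributions.generalized_pareto import GeneralizedPareto

# loc/scale/concentration replace the old xi/beta parameterization
d = GeneralizedPareto(
    loc=torch.tensor(0.0),
    scale=torch.tensor(2.0),
    concentration=torch.tensor(0.4),
)

x = d.sample((5,))              # inverse-CDF sampling: icdf of uniform draws
print(d.log_prob(x))            # log density at the samples
q = d.icdf(torch.tensor(0.95))  # 95% quantile
print(d.cdf(q))                 # ~0.95: cdf/icdf round-trip
print(d.mean, d.variance)       # finite here, since concentration < 1/2

# concentration == 0 takes the exponential branch of log_prob
e = GeneralizedPareto(torch.tensor(0.0), torch.tensor(2.0), torch.tensor(0.0))
ref = torch.distributions.Exponential(rate=0.5)
x0 = torch.tensor([0.5, 1.0, 3.0])
print(torch.allclose(e.log_prob(x0), ref.log_prob(x0)))  # True

The safe_conc/torch.where pattern in log_prob, icdf, mean, and variance evaluates both branches on substituted safe values before selecting, which keeps NaNs (and NaN gradients) out of the unselected branch; this appears to mirror the TensorFlow Probability implementation the docstring cites.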


class GeneralizedParetoOutput(DistributionOutput):
@@ -151,32 +189,18 @@ def __init__(
 
         self.args_dim = cast(
             Dict[str, int],
-            {
-                "xi": 1,
-                "beta": 1,
-            },
+            {"loc": 1, "scale": 1, "concentration": 1},
         )
 
     @classmethod
-    def domain_map(  # type: ignore
+    def domain_map(
         cls,
-        xi: torch.Tensor,
-        beta: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        xi = torch.abs(xi)
-        beta = torch.abs(beta)
-
-        return xi, beta
-
-    def distribution(
-        self,
-        distr_args,
-        loc: Optional[torch.Tensor] = None,
-        scale: Optional[torch.Tensor] = None,
-    ) -> GeneralizedPareto:
-        return self.distr_cls(
-            *distr_args,
-        )
+        loc: torch.Tensor,
+        scale: torch.Tensor,
+        concentration: torch.Tensor,
+    ):  # type: ignore
+        scale = F.softplus(scale)
+        return loc.squeeze(-1), scale.squeeze(-1), concentration.squeeze(-1)
 
     @property
     def event_shape(self) -> Tuple:
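Finally, a sketch of how the new output head maps raw network activations to valid parameters (hypothetical shapes; constructing GeneralizedPareto directly stands in for the DistributionOutput.distribution plumbing that is not visible in this truncated diff):

import torch

from gluonts.torch.distributions.generalized_pareto import (
    GeneralizedPareto,
    GeneralizedParetoOutput,
)

# one unconstrained activation per parameter, as declared in args_dim
batch = 32
raw_loc = torch.randn(batch, 1)
raw_scale = torch.randn(batch, 1)
raw_concentration = torch.randn(batch, 1)

# domain_map now softplus-constrains scale (the old code used torch.abs)
# and squeezes the trailing singleton axis declared by args_dim
loc, scale, concentration = GeneralizedParetoOutput.domain_map(
    raw_loc, raw_scale, raw_concentration
)
assert (scale > 0).all()

distr = GeneralizedPareto(loc, scale, concentration)
print(distr.batch_shape)  # torch.Size([32])

Note that concentration is left unconstrained (constraints.real), so the head can learn negative shapes, i.e. bounded-support tails, which the old xi-positive parameterization ruled out.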

