projnormal.classes.const
Class for the general projected normal distribution with an additive constant in the denominator.
Classes
ProjNormalConst – Projected normal distribution variant, describing the variable \(y=x/\sqrt{||x||^2 + c}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution and \(c\) is a positive constant.
- class ProjNormalConst(n_dim=None, mean_x=None, covariance_x=None, const=None)
Bases: ProjNormal
Projected normal distribution variant, describing the variable \(y=x/\sqrt{||x||^2 + c}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution and \(c\) is a positive constant.
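Because \(c > 0\), every draw satisfies \(||y||^2 = ||x||^2/(||x||^2 + c) < 1\), so \(y\) lies strictly inside the unit ball rather than on the unit sphere. The short sketch below samples directly from this definition using `torch.distributions.MultivariateNormal`. The helper `sample_proj_normal_const` is hypothetical and not part of projnormal, which may implement sampling differently.

```python
import torch

def sample_proj_normal_const(mean_x, covariance_x, const, n_samples):
    """Hypothetical helper: draw y = x / sqrt(||x||^2 + c) with x ~ N(mean_x, covariance_x)."""
    mvn = torch.distributions.MultivariateNormal(mean_x, covariance_matrix=covariance_x)
    x = mvn.sample((n_samples,))  # shape (n_samples, n_dim)
    # Squared norm of each sample, kept as a column so it broadcasts over n_dim.
    sq_norm = (x ** 2).sum(dim=-1, keepdim=True)
    return x / torch.sqrt(sq_norm + const)

torch.manual_seed(0)
mean_x = torch.tensor([1.0, 0.0, 0.0])
covariance_x = 0.5 * torch.eye(3)
y = sample_proj_normal_const(mean_x, covariance_x, const=1.0, n_samples=1000)
print(y.shape)                          # torch.Size([1000, 3])
print(bool((y.norm(dim=-1) < 1).all()))  # True: samples stay inside the unit ball
```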
- Parameters:
  - n_dim (int) – Dimension of \(x\) (the embedding space). Optional: if mean_x and covariance_x are provided, it is not required.
  - mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
  - covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
  - const (torch.Tensor, optional) – The denominator additive constant. Scalar. It is constrained to be positive. Default is 1.0.
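The parameter rules above (n_dim may be omitted only when both mean_x and covariance_x are given; otherwise the documented defaults fill the gaps) can be sketched as a hypothetical `resolve_params` helper. This is not the library's constructor. In particular, the random default for mean_x is assumed here to be a standard-normal draw rescaled to unit norm, to match the unit-norm constraint listed under Variables; projnormal may generate it differently.

```python
import torch

def resolve_params(n_dim=None, mean_x=None, covariance_x=None, const=None):
    """Hypothetical helper: apply the documented parameter defaults."""
    if n_dim is None:
        # n_dim may only be omitted when both mean_x and covariance_x are given.
        if mean_x is None or covariance_x is None:
            raise ValueError("n_dim is required unless mean_x and covariance_x are provided")
        n_dim = mean_x.shape[0]
    if mean_x is None:
        # Assumed random default: standard-normal draw rescaled to unit norm.
        mean_x = torch.randn(n_dim)
        mean_x = mean_x / mean_x.norm()
    if covariance_x is None:
        covariance_x = torch.eye(n_dim)  # documented default: identity
    if const is None:
        const = torch.tensor(1.0)  # documented default: 1.0
    # Check the documented shapes (n_dim) and (n_dim, n_dim).
    if mean_x.shape != (n_dim,) or covariance_x.shape != (n_dim, n_dim):
        raise ValueError("expected mean_x of shape (n_dim,) and covariance_x of shape (n_dim, n_dim)")
    if const <= 0:
        raise ValueError("const must be positive")
    return n_dim, mean_x, covariance_x, const

n_dim, mean_x, covariance_x, const = resolve_params(n_dim=4)
print(n_dim, covariance_x.shape, const.item())  # 4 torch.Size([4, 4]) 1.0
```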
- Variables:
  - mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
  - covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD (symmetric positive definite). Shape (n_dim, n_dim).
  - const (torch.Tensor) – Denominator additive constant. Learnable parameter constrained to be positive. Shape (1,).
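A common way to keep learnable parameters inside constraint sets like these is to optimize unconstrained tensors and map them through fixed transforms. The sketch below assumes three such transforms: normalization for the unit-norm mean, a lower-triangular factor plus a small jitter for the SPD covariance, and softplus for the positive constant. The function name `constrain` and the jitter `eps` are hypothetical, and projnormal may use different parametrizations.

```python
import torch
import torch.nn.functional as F

def constrain(raw_mean, raw_factor, raw_const, eps=1e-6):
    """Hypothetical: map unconstrained tensors to a unit-norm mean, SPD covariance and positive constant."""
    # Normalize onto the unit sphere (raw_mean must be nonzero).
    mean_x = raw_mean / raw_mean.norm()
    # L @ L.T is symmetric positive semi-definite for any lower-triangular L;
    # adding eps * I makes it strictly positive definite.
    L = torch.tril(raw_factor)
    covariance_x = L @ L.T + eps * torch.eye(L.shape[0], dtype=L.dtype)
    # softplus(z) = log(1 + exp(z)) is positive for every real z.
    const = F.softplus(raw_const)
    return mean_x, covariance_x, const

torch.manual_seed(0)
mean_x, covariance_x, const = constrain(torch.randn(3), torch.randn(3, 3), torch.tensor([-2.0]))
print(mean_x.norm())                        # 1 up to rounding
print(torch.linalg.eigvalsh(covariance_x))  # all eigenvalues positive
print(const)                                # positive, shape (1,)
```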
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
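A log pdf of this form can be computed by a change of variables, assuming only the definition \(y=x/\sqrt{||x||^2 + c}\); the sketch below is not claimed to be the library's implementation. Since \(||x||^2 + c = c/(1-||y||^2)\), the inverse map is \(x = y\sqrt{c/(1-||y||^2)}\), and its Jacobian determinant gives

\[\log p_y(y) = \log \mathcal{N}\left(x(y); \mu_x, \Sigma_x\right) + \frac{n}{2}\log c - \left(\frac{n}{2}+1\right)\log\left(1-||y||^2\right),\]

with \(n\) equal to n_dim. The helper name `log_pdf_proj_normal_const` is hypothetical.

```python
import torch

def log_pdf_proj_normal_const(y, mean_x, covariance_x, const):
    """Hypothetical: log density of y = x / sqrt(||x||^2 + c), x ~ N(mean_x, covariance_x).

    y has shape (n_points, n_dim) and every row must satisfy ||y|| < 1.
    Returns a tensor of shape (n_points,).
    """
    n_dim = y.shape[-1]
    const = torch.as_tensor(const, dtype=y.dtype)
    one_minus_sq = 1.0 - (y ** 2).sum(dim=-1)  # 1 - ||y||^2, shape (n_points,)
    # Inverse map: x = y * sqrt(c / (1 - ||y||^2)).
    x = y * torch.sqrt(const / one_minus_sq).unsqueeze(-1)
    mvn = torch.distributions.MultivariateNormal(mean_x, covariance_matrix=covariance_x)
    # log|det(dx/dy)| = (n/2) log c - (n/2 + 1) log(1 - ||y||^2)
    log_det_jac = 0.5 * n_dim * torch.log(const) - (0.5 * n_dim + 1.0) * torch.log(one_minus_sq)
    return mvn.log_prob(x) + log_det_jac

# Check in one dimension: the density should integrate to about 1 over (-1, 1).
grid = torch.linspace(-1 + 1e-6, 1 - 1e-6, 200001, dtype=torch.float64)
density = log_pdf_proj_normal_const(
    grid.unsqueeze(-1),
    mean_x=torch.tensor([0.3], dtype=torch.float64),
    covariance_x=torch.tensor([[1.0]], dtype=torch.float64),
    const=1.0,
).exp()
print(torch.trapezoid(density, grid).item())  # close to 1.0
```

The one-dimensional integral is a cheap sanity check on the Jacobian term: dropping or mis-scaling the \(\log(1-||y||^2)\) factor makes the total visibly differ from 1.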