projnormal.classes.ellipse_const

Class for the general projected normal distribution.

Classes

ProjNormalEllipseConst([n_dim, mean_x, ...])


class ProjNormalEllipseConst(n_dim=None, mean_x=None, covariance_x=None, const=None, B=None)

Bases: ProjNormalEllipse

Projected normal distribution variant, describing the variable \(y=x/\sqrt{x^T B x + c}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution, \(B\) is a symmetric positive definite matrix, and \(c\) is a positive scalar constant.
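The definition can be illustrated by sampling from it directly. The sketch below uses only PyTorch and does not call the `projnormal` API. It draws \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) and applies \(y = x/\sqrt{x^T B x + c}\). The helper name `sample_ellipse_const` and the example parameter values are hypothetical. Because \(y^T B y = x^T B x / (x^T B x + c)\) and \(c > 0\), every sample lies strictly inside the ellipsoid \(y^T B y < 1\).

```python
import torch
from torch.distributions import MultivariateNormal


def sample_ellipse_const(mean_x, covariance_x, const, B, n_samples):
    """Hypothetical helper: draw samples of y = x / sqrt(x^T B x + c)."""
    # Sample x ~ N(mean_x, covariance_x); shape (n_samples, n_dim)
    x = MultivariateNormal(mean_x, covariance_matrix=covariance_x).sample((n_samples,))
    # Quadratic form x^T B x for each sample; shape (n_samples,)
    quad = torch.einsum("ni,ij,nj->n", x, B, x)
    # Divide each sample by sqrt(x^T B x + c)
    return x / torch.sqrt(quad + const).unsqueeze(-1)


torch.manual_seed(0)
n_dim = 3
mean_x = torch.tensor([1.0, 0.0, 0.0])
covariance_x = 0.5 * torch.eye(n_dim)
const = torch.tensor([1.0])
B = torch.diag(torch.tensor([1.0, 2.0, 4.0]))

y = sample_ellipse_const(mean_x, covariance_x, const, B, n_samples=1000)
print(y.shape)  # torch.Size([1000, 3])
# Every sample satisfies y^T B y < 1
print(torch.einsum("ni,ij,nj->n", y, B, y).max() < 1)
```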

Parameters:
  • n_dim (int) – Dimension of \(x\) (the embedding space). Optional if mean_x and covariance_x are provided.

  • mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim,). Default is a random vector.

  • covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity matrix.

  • const (torch.Tensor, optional) – Additive constant \(c\) in the denominator. Shape (1,). Default is 1.

  • B (torch.Tensor, optional) – SPD matrix \(B\) defining the ellipse. Shape (n_dim, n_dim). Default is the identity matrix.
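The ellipse defined by `B` is the set \(\{y : y^T B y = 1\}\), which bounds the support of \(y\). One way to build an SPD `B` with prescribed semi-axis lengths \(a_i\) along orthonormal directions (the columns of a rotation matrix \(R\)) is \(B = R\,\mathrm{diag}(a_i^{-2})\,R^T\). The helper `ellipse_matrix` below is a hypothetical illustration in plain PyTorch. The commented lines sketch how tensors with the documented shapes could be passed to the constructor, assuming the `projnormal` package is installed.

```python
import math
import torch


def ellipse_matrix(semi_axes, rotation):
    """Hypothetical helper: SPD matrix B whose ellipse y^T B y = 1 has the given semi-axes."""
    # B = R diag(1 / a_i^2) R^T is symmetric, and positive definite when all a_i > 0
    return rotation @ torch.diag(semi_axes ** -2) @ rotation.T


# 2-D example: semi-axes 2 and 0.5, rotated by 30 degrees
theta = math.pi / 6
rotation = torch.tensor([[math.cos(theta), -math.sin(theta)],
                         [math.sin(theta), math.cos(theta)]])
semi_axes = torch.tensor([2.0, 0.5])
B = ellipse_matrix(semi_axes, rotation)

# A point at distance a_1 along the first rotated axis lies on the ellipse
point = semi_axes[0] * rotation[:, 0]
print(point @ B @ point)  # approximately 1.0

# Arguments with the documented shapes (requires the projnormal package):
# from projnormal.classes.ellipse_const import ProjNormalEllipseConst
# prnorm = ProjNormalEllipseConst(
#     mean_x=torch.tensor([1.0, 0.0]),   # shape (n_dim,)
#     covariance_x=torch.eye(2),         # shape (n_dim, n_dim)
#     const=torch.tensor([1.0]),         # shape (1,)
#     B=B,                               # shape (n_dim, n_dim), SPD
# )
```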

Variables:
  • mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim,).

  • covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).

  • const (torch.Tensor) – Additive constant \(c\) in the denominator. Shape (1,). Learnable parameter constrained to be positive.

  • B (torch.Tensor) – Quadratic form matrix of the denominator. Shape (n_dim, n_dim). Learnable parameter constrained to be SPD.
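Constraints like these (unit norm, SPD, positive) are commonly enforced by optimizing unconstrained tensors and mapping them onto the constraint set. The sketch below shows one way to do this with `torch.nn.utils.parametrize`. The container `ConstrainedParams` and the specific maps (normalization, \(LL^T + \epsilon I\), softplus) are assumptions for illustration; the package may use different parametrizations.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parametrize


class UnitNorm(nn.Module):
    # Maps any nonzero vector onto the unit sphere
    def forward(self, v):
        return v / torch.linalg.vector_norm(v)


class SPD(nn.Module):
    # Maps an unconstrained square matrix to L L^T + eps I, with L lower triangular
    def __init__(self, eps=1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, m):
        L = torch.tril(m)
        return L @ L.T + self.eps * torch.eye(m.shape[-1], dtype=m.dtype)


class Positive(nn.Module):
    # Maps any real value to a strictly positive one
    def forward(self, v):
        return F.softplus(v)


class ConstrainedParams(nn.Module):
    """Hypothetical container mirroring the documented learnable variables."""

    def __init__(self, n_dim):
        super().__init__()
        self.mean_x = nn.Parameter(torch.randn(n_dim))
        self.covariance_x = nn.Parameter(torch.eye(n_dim))
        self.const = nn.Parameter(torch.zeros(1))
        self.B = nn.Parameter(torch.eye(n_dim))
        # Reading each attribute now returns the constrained value
        parametrize.register_parametrization(self, "mean_x", UnitNorm())
        parametrize.register_parametrization(self, "covariance_x", SPD())
        parametrize.register_parametrization(self, "const", Positive())
        parametrize.register_parametrization(self, "B", SPD())


torch.manual_seed(0)
params = ConstrainedParams(n_dim=3)
print(torch.linalg.vector_norm(params.mean_x))  # approximately 1
print(params.const)  # softplus(0) = log(2), approximately 0.693
```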

log_pdf(y)

Compute the log-PDF at the points y.

Parameters:

y (torch.Tensor) – Points at which to evaluate the log-PDF. Shape (n_points, n_dim).

Returns:

Log-PDF of y. Shape (n_points,).

Return type:

torch.Tensor
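Because \(y = x/\sqrt{x^T B x + c}\) is invertible on the ellipsoid \(y^T B y < 1\), with inverse \(x = y\sqrt{c/(1 - s)}\) where \(s = y^T B y\), a log-PDF follows from the change-of-variables formula. The Jacobian of the inverse map is the rank-one update \(aI + y\,\nabla a^T\) with \(a = \sqrt{c/(1-s)}\), whose determinant is \(a^n(1 + y^T\nabla a / a) = c^{n/2}(1-s)^{-(n/2+1)}\), with \(n\) equal to `n_dim`. Hence \(\log p(y) = \log \mathcal{N}\left(y\sqrt{c/(1-s)};\, \mu_x, \Sigma_x\right) + \tfrac{n}{2}\log c - \left(\tfrac{n}{2}+1\right)\log(1-s)\). The function `log_pdf_sketch` below is a hypothetical plain-PyTorch implementation of this formula, useful for cross-checking. It is not taken from the package, whose implementation may differ.

```python
import torch
from torch.distributions import MultivariateNormal


def log_pdf_sketch(y, mean_x, covariance_x, const, B):
    """Hypothetical log-PDF of y = x / sqrt(x^T B x + c), x ~ N(mean_x, covariance_x).

    y has shape (n_points, n_dim); the result has shape (n_points,).
    Points must satisfy y^T B y < 1 (the support of the distribution).
    """
    n_dim = y.shape[-1]
    # s = y^T B y for each point
    s = torch.einsum("ni,ij,nj->n", y, B, y)
    # Inverse map: x = y * sqrt(c / (1 - s))
    x = y * torch.sqrt(const / (1 - s)).unsqueeze(-1)
    # Log density of the pre-image under the normal distribution of x
    log_px = MultivariateNormal(mean_x, covariance_matrix=covariance_x).log_prob(x)
    # Log |det J| of the inverse map: (n/2) log c - (n/2 + 1) log(1 - s)
    log_det = 0.5 * n_dim * torch.log(const) - (0.5 * n_dim + 1) * torch.log1p(-s)
    return log_px + log_det


mean_x = torch.tensor([1.0, -0.5])
covariance_x = torch.tensor([[1.0, 0.3], [0.3, 0.5]])
const = torch.tensor([2.0])
B = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
y = torch.tensor([[0.1, 0.2], [-0.3, 0.4]])  # both points satisfy y^T B y < 1
log_p = log_pdf_sketch(y, mean_x, covariance_x, const, B)
print(log_p.shape)  # torch.Size([2])
```

The closed form avoids computing a determinant numerically; the rank-one structure of the Jacobian is what makes this possible.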