projnormal.classes
Classes for fitting the projected normal and related distributions.
- class ProjNormal(n_dim=None, mean_x=None, covariance_x=None)
Bases: Module
Projected normal distribution, describing the variable \(y=x/\|x\|\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution.
- Parameters:
n_dim (int) – Dimension of \(x\) (the embedding space). Optional: if mean_x and covariance_x are provided, it is not required.
mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
- Variables:
mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).
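The two constraints above can be illustrated with a minimal NumPy sketch of one common parametrization: an unconstrained vector is divided by its norm to give a unit-norm mean, and an unconstrained matrix is turned into an SPD covariance through a lower-triangular factor. The helpers unit_norm and spd_from_unconstrained are hypothetical; the library's own constraint classes may use a different parametrization.

```python
import numpy as np


# Hypothetical helpers illustrating the constraints; not part of projnormal.
def unit_norm(v: np.ndarray) -> np.ndarray:
    """Project an unconstrained vector onto the unit sphere."""
    return v / np.linalg.norm(v)


def spd_from_unconstrained(a: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Map an unconstrained square matrix to a symmetric positive definite one.

    The lower triangle acts as a Cholesky-like factor L; L @ L.T is positive
    semidefinite, and adding eps * I makes it strictly positive definite.
    """
    lower = np.tril(a)
    return lower @ lower.T + eps * np.eye(a.shape[0])


rng = np.random.default_rng(0)
mean_x = unit_norm(rng.normal(size=3))
covariance_x = spd_from_unconstrained(rng.normal(size=(3, 3)))
```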
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
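As a reference point for log_pdf, the density of \(y\) on the unit sphere (with respect to surface measure) can be written as a one-dimensional radial integral, \(p(y) = \int_0^\infty r^{n-1}\, \mathcal{N}(r y; \mu_x, \Sigma_x)\, dr\). The sketch below evaluates this integral numerically with NumPy. It is an illustrative check under that definition, not necessarily how the library computes the density, and projnormal_log_pdf is a hypothetical name. The pdf is the exponential of the same quantity.

```python
import numpy as np


def projnormal_log_pdf(y, mean_x, covariance_x, n_grid=20000):
    """Log-density of y = x / ||x|| on the unit sphere, with x ~ N(mean_x, covariance_x).

    Hypothetical helper: evaluates p(y) = int_0^inf r^(n-1) N(r y; mean_x, covariance_x) dr
    with the trapezoidal rule on a radial grid.
    """
    y = np.atleast_2d(y)
    n = y.shape[1]
    prec = np.linalg.inv(covariance_x)
    _, logdet = np.linalg.slogdet(covariance_x)
    a = np.einsum("pi,ij,pj->p", y, prec, y)  # y^T Sigma^-1 y
    b = y @ prec @ mean_x                      # y^T Sigma^-1 mu
    c = mean_x @ prec @ mean_x                 # mu^T Sigma^-1 mu
    out = np.empty(len(y))
    for k in range(len(y)):
        # Mode of r^(n-1) exp(-a r^2 / 2 + b r); the grid extends well past it.
        r_mode = (b[k] + np.sqrt(b[k] ** 2 + 4 * a[k] * (n - 1))) / (2 * a[k])
        r = np.linspace(0.0, r_mode + 12.0 / np.sqrt(a[k]), n_grid)
        f = r ** (n - 1) * np.exp(-0.5 * a[k] * r**2 + b[k] * r)
        integral = (r[1] - r[0]) * (f.sum() - 0.5 * (f[0] + f[-1]))
        out[k] = np.log(integral) - 0.5 * (n * np.log(2 * np.pi) + logdet + c)
    return out
```

For \(\mu_x = 0\) and \(\Sigma_x = I\) in two dimensions, the result is the uniform density on the circle, \(-\log(2\pi)\).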
- max_likelihood(y, max_epochs=300, lr=0.1, optimizer='NAdam', show_progress=True, return_loss=False, n_cycles=1, cycle_gamma=0.5, **kwargs)
Fit the distribution parameters via maximum likelihood, by iteratively minimizing the negative log-likelihood of the data y.
- Parameters:
y (torch.Tensor) – Observed data. Shape (n_samples, n_dim).
max_epochs (int) – Maximum number of training epochs. Default is 300.
lr (float) – Learning rate for the optimizer. Default is 0.1.
optimizer (str) – Optimizer to use. Options are ‘LBFGS’ and ‘NAdam’. Default is ‘NAdam’.
show_progress (bool) – If True, show a progress bar during training. Default is True.
return_loss (bool) – If True, return the loss after training. Default is False.
n_cycles (int) – For the NAdam optimizer, the number of times to run the optimization loop. Default is 1.
cycle_gamma (float) – Factor by which lr is reduced after each repetition of the optimization loop. Default is 0.5.
**kwargs – Additional keyword arguments passed to the fitting function. For the NAdam optimizer, the parameters gamma and step_size can be passed to control the learning rate schedule.
- Returns:
Dictionary containing the loss and training time.
- Return type:
dict
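The interaction of lr, n_cycles and cycle_gamma can be sketched with plain gradient descent on a toy quadratic objective: the inner loop runs max_epochs steps, and lr is multiplied by cycle_gamma before the next cycle starts. This is a hypothetical stand-in: cycled_gradient_descent is not a library function, and the library uses NAdam or LBFGS rather than plain gradient descent.

```python
import numpy as np


def cycled_gradient_descent(grad, x0, lr=0.1, max_epochs=300, n_cycles=1, cycle_gamma=0.5):
    """Run gradient descent n_cycles times, shrinking lr by cycle_gamma after each cycle.

    Hypothetical helper mirroring the documented roles of n_cycles and cycle_gamma.
    Returns the final iterate and the learning rate used in each cycle.
    """
    x = np.asarray(x0, dtype=float)
    lrs = []
    for _ in range(n_cycles):
        lrs.append(lr)
        for _ in range(max_epochs):
            x = x - lr * grad(x)
        lr *= cycle_gamma  # reduce the learning rate for the next cycle
    return x, lrs


# Toy objective ||x - target||^2, whose gradient is 2 (x - target).
target = np.array([1.0, -2.0])
x_hat, lrs = cycled_gradient_descent(lambda x: 2 * (x - target), x0=[0.0, 0.0], n_cycles=3)
```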
- moment_init(data_moments)
Initialize the distribution parameters using the observed moments as the initial guess (after making the mean unit norm).
- Parameters:
data_moments (dict) – Dictionary with keys mean and covariance, containing the observed mean and covariance of the data, respectively.
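A minimal NumPy sketch of the initialization described above, assuming it simply rescales the observed mean to unit norm and copies the observed covariance. moment_init_guess is a hypothetical helper, not the library method.

```python
import numpy as np


def moment_init_guess(data_moments):
    """Initial (mean_x, covariance_x) from observed moments of y.

    Hypothetical helper: the observed mean is rescaled to unit norm,
    and the observed covariance is used as-is.
    """
    mean = np.asarray(data_moments["mean"], dtype=float)
    covariance = np.asarray(data_moments["covariance"], dtype=float)
    return mean / np.linalg.norm(mean), covariance.copy()


# Observed moments of synthetic projected-normal data.
rng = np.random.default_rng(1)
x = rng.multivariate_normal([2.0, 1.0, 0.0], 0.3 * np.eye(3), size=5000)
y = x / np.linalg.norm(x, axis=1, keepdims=True)
data_moments = {"mean": y.mean(axis=0), "covariance": np.cov(y, rowvar=False)}
mean_x0, covariance_x0 = moment_init_guess(data_moments)
```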
- moment_match(data_moments, max_epochs=200, lr=0.1, optimizer='NAdam', loss_fun=None, show_progress=True, return_loss=False, n_cycles=3, cycle_gamma=0.5, **kwargs)
Fit the distribution parameters via moment matching.
- Parameters:
data_moments (dict) – Dictionary containing the observed moments, with keys mean and covariance.
max_epochs (int) – Maximum number of training epochs. Default is 200.
lr (float) – Learning rate for the optimizer. Default is 0.1.
optimizer (str) – Optimizer to use. Options are ‘LBFGS’ and ‘NAdam’. Default is ‘NAdam’.
loss_fun (callable) – Loss function to use for moment matching. Default is squared error.
show_progress (bool) – If True, show a progress bar during training. Default is True.
return_loss (bool) – If True, return the loss after training. Default is False.
n_cycles (int) – For the NAdam optimizer, the number of times to run the optimization loop. Default is 3.
cycle_gamma (float) – Factor by which lr is reduced after each repetition of the optimization loop. Default is 0.5.
**kwargs – Additional keyword arguments passed to the fitting function. For the NAdam optimizer, the parameters gamma and step_size can be passed to control the learning rate schedule.
- Returns:
Dictionary containing the loss and training time.
- Return type:
dict
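The default squared-error criterion can be sketched as the summed squared difference between model and data moments. Weighting the mean and covariance terms equally is an assumption of this sketch, and squared_error_moment_loss is a hypothetical name.

```python
import numpy as np


def squared_error_moment_loss(model_moments, data_moments):
    """Sum of squared differences between model and data mean and covariance.

    Hypothetical helper; equal weighting of the two terms is an assumption.
    """
    mean_err = np.sum((model_moments["mean"] - data_moments["mean"]) ** 2)
    cov_err = np.sum((model_moments["covariance"] - data_moments["covariance"]) ** 2)
    return mean_err + cov_err


data = {"mean": np.array([1.0, 0.0]), "covariance": 0.1 * np.eye(2)}
model = {"mean": np.array([0.9, 0.0]), "covariance": 0.1 * np.eye(2)}
loss = squared_error_moment_loss(model, data)  # only the mean differs, by 0.1
```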
- moments()
Compute moments of the distribution via Taylor approximation.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
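To make the Taylor approximation concrete, a second-order expansion of \(f(x) = x/\|x\|\) around \(\mu_x\) gives, with \(r = \|\mu_x\|\),
\(\mathbb{E}[y] \approx \mu_x/r - \Sigma_x\mu_x/r^3 - \tfrac{1}{2}\operatorname{tr}(\Sigma_x)\,\mu_x/r^3 + \tfrac{3}{2}(\mu_x^T\Sigma_x\mu_x)\,\mu_x/r^5\).
The NumPy sketch below implements only this mean term and compares it with a Monte Carlo estimate. It is a generic delta-method approximation and may differ from the exact scheme used by moments(); taylor_mean is a hypothetical name.

```python
import numpy as np


def taylor_mean(mean_x, covariance_x):
    """Second-order Taylor (delta-method) approximation of E[x / ||x||].

    Hypothetical helper; covers the mean only, not covariance or second moment.
    """
    mu = np.asarray(mean_x, dtype=float)
    sigma = np.asarray(covariance_x, dtype=float)
    r = np.linalg.norm(mu)
    quad = mu @ sigma @ mu
    return (
        mu / r
        - sigma @ mu / r**3
        - 0.5 * np.trace(sigma) * mu / r**3
        + 1.5 * quad * mu / r**5
    )


mu = np.array([3.0, 1.0, 0.5])
sigma = 0.1 * np.eye(3)
rng = np.random.default_rng(0)
x = rng.multivariate_normal(mu, sigma, size=400_000)
mc_mean = (x / np.linalg.norm(x, axis=1, keepdims=True)).mean(axis=0)
approx = taylor_mean(mu, sigma)
```

The correction terms matter: the naive estimate \(\mu_x/\|\mu_x\|\) is noticeably farther from the Monte Carlo mean than the second-order approximation.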
- moments_empirical(n_samples=200000)
Compute moments of the distribution via sampling.
- Parameters:
n_samples (int) – Number of samples to draw for the empirical moments. Default is 200000.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
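Empirical moments reduce to sample averages. The minimal NumPy sketch below (empirical_moments is a hypothetical helper) returns the same three keys. Because every sample of \(y\) has unit norm, the trace of second_moment is exactly 1, which is a convenient sanity check.

```python
import numpy as np


def empirical_moments(samples):
    """Monte Carlo estimates of the mean, covariance and second moment.

    Hypothetical helper; the covariance is the biased (1/N) estimate.
    """
    mean = samples.mean(axis=0)
    second_moment = samples.T @ samples / len(samples)
    covariance = second_moment - np.outer(mean, mean)
    return {"mean": mean, "covariance": covariance, "second_moment": second_moment}


rng = np.random.default_rng(2)
x = rng.multivariate_normal([1.0, -0.5, 0.2], np.diag([0.5, 0.2, 0.1]), size=200_000)
y = x / np.linalg.norm(x, axis=1, keepdims=True)
m = empirical_moments(y)
```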
- pdf(y)
Compute the pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the pdf. Shape (n_points, n_dim).
- Returns:
PDF of the points y. Shape (n_points).
- Return type:
torch.Tensor
- sample(n_samples)
Sample from the distribution.
- Parameters:
n_samples (int) – Number of samples to draw.
- Returns:
Samples from the distribution. Shape (n_samples, n_dim).
- Return type:
torch.Tensor
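Sampling follows directly from the definition: draw \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) and divide each draw by its norm. A minimal NumPy sketch, with sample_projnormal as a hypothetical name:

```python
import numpy as np


def sample_projnormal(mean_x, covariance_x, n_samples, rng=None):
    """Draw y = x / ||x|| with x ~ N(mean_x, covariance_x). Hypothetical helper."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.multivariate_normal(mean_x, covariance_x, size=n_samples)
    return x / np.linalg.norm(x, axis=1, keepdims=True)


samples = sample_projnormal(
    np.array([1.0, 0.0, 0.0]), 0.5 * np.eye(3), 1000, rng=np.random.default_rng(3)
)
```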
- class ProjNormalConst(n_dim=None, mean_x=None, covariance_x=None, const=None)
Bases: ProjNormal
Projected normal distribution variant, describing the variable \(y=x/\sqrt{\|x\|^2 + c}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution and \(c\) is a positive constant.
- Parameters:
n_dim (int) – Dimension of \(x\) (the embedding space). Optional: if mean_x and covariance_x are provided, it is not required.
mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
const (torch.Tensor, optional) – The denominator additive constant. Scalar. It is constrained to be positive. Default is 1.0.
- Variables:
mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).
const (torch.Tensor) – Denominator additive constant. Learnable parameter constrained to be positive. Shape (1,).
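The only change with respect to ProjNormal is the additive constant in the denominator. A minimal NumPy sampling sketch (hypothetical helper name) shows its effect: since \(c > 0\), every sample lies strictly inside the unit ball instead of on the sphere.

```python
import numpy as np


def sample_projnormal_const(mean_x, covariance_x, const, n_samples, rng=None):
    """Draw y = x / sqrt(||x||^2 + const) with x ~ N(mean_x, covariance_x). Hypothetical helper."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.multivariate_normal(mean_x, covariance_x, size=n_samples)
    return x / np.sqrt(np.sum(x**2, axis=1, keepdims=True) + const)


samples = sample_projnormal_const(
    np.array([2.0, 0.0]), np.eye(2), 1.0, 1000, rng=np.random.default_rng(5)
)
```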
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
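Because the map \(x \mapsto x/\sqrt{\|x\|^2 + c}\) is invertible from \(\mathbb{R}^n\) onto the open unit ball, with inverse \(x = \sqrt{c}\, y/\sqrt{1-\|y\|^2}\), a log pdf with respect to Lebesgue measure in the ball follows from a change of variables:
\(\log p(y) = \log \mathcal{N}(x(y); \mu_x, \Sigma_x) + \tfrac{n}{2}\log c - \left(\tfrac{n}{2}+1\right)\log(1-\|y\|^2)\).
The NumPy sketch below implements this formula. It is a derivation for illustration, not the library's code, and projnormal_const_log_pdf is a hypothetical name.

```python
import numpy as np


def projnormal_const_log_pdf(y, mean_x, covariance_x, const):
    """Log-density of y = x / sqrt(||x||^2 + const) inside the unit ball. Hypothetical helper.

    Maps y back to x = sqrt(const) * y / sqrt(1 - ||y||^2) and adds the log-Jacobian
    (n/2) log(const) - (n/2 + 1) log(1 - ||y||^2).
    """
    y = np.atleast_2d(y)
    n = y.shape[1]
    s = 1.0 - np.sum(y**2, axis=1)  # 1 - ||y||^2, must be > 0
    x = np.sqrt(const) * y / np.sqrt(s)[:, None]
    diff = x - mean_x
    prec = np.linalg.inv(covariance_x)
    _, logdet = np.linalg.slogdet(covariance_x)
    mahal = np.einsum("pi,ij,pj->p", diff, prec, diff)
    log_normal = -0.5 * (mahal + n * np.log(2 * np.pi) + logdet)
    return log_normal + 0.5 * n * np.log(const) - (0.5 * n + 1.0) * np.log(s)
```

In one dimension the density can be checked by integrating it over \((-1, 1)\), which should give 1.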
- class ProjNormalEllipse(n_dim=None, mean_x=None, covariance_x=None, B=None)
Bases: ProjNormal
Projected normal distribution variant, describing the variable \(y=x/\sqrt{x^T B x}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution and \(B\) is a symmetric positive definite matrix.
- Parameters:
n_dim (int) – Dimension of \(x\) (the embedding space). Optional: if mean_x and covariance_x are provided, it is not required.
mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
B (torch.Tensor, optional) – SPD matrix defining the ellipse. Shape (n_dim, n_dim). Default is the identity matrix.
- Variables:
mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).
B (torch.Tensor) – Quadratic form matrix of the denominator. Shape (n_dim, n_dim). Learnable parameter constrained to be SPD.
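Sampling again follows the definition; here each draw is scaled so that \(y^T B y = 1\), that is, the samples lie on the ellipse (ellipsoid) defined by \(B\). A minimal NumPy sketch with a hypothetical helper name:

```python
import numpy as np


def sample_projnormal_ellipse(mean_x, covariance_x, B, n_samples, rng=None):
    """Draw y = x / sqrt(x^T B x) with x ~ N(mean_x, covariance_x). Hypothetical helper."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.multivariate_normal(mean_x, covariance_x, size=n_samples)
    quad = np.einsum("pi,ij,pj->p", x, B, x)  # x^T B x for each sample
    return x / np.sqrt(quad)[:, None]


B = np.array([[2.0, 0.3], [0.3, 0.5]])
samples = sample_projnormal_ellipse(
    np.array([1.0, 1.0]), 0.5 * np.eye(2), B, 1000, rng=np.random.default_rng(6)
)
```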
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
- moments()
Compute moments of the distribution via Taylor approximation.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
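A useful identity for this variant: with the Cholesky factorization \(B = L L^T\) and \(z = L^T x\), one has \(x^T B x = \|z\|^2\), so \(y = L^{-T} z/\|z\|\) with \(z \sim \mathcal{N}(L^T\mu_x, L^T\Sigma_x L)\). The moments of \(y\) are therefore a linear transformation of standard projected-normal moments, for example \(\mathbb{E}[y] = L^{-T}\,\mathbb{E}[z/\|z\|]\). Whether moments() uses this reduction is not stated here; the NumPy sketch below only checks the identity on samples.

```python
import numpy as np

rng = np.random.default_rng(4)
B = np.array([[2.0, 0.3, 0.0], [0.3, 1.0, 0.2], [0.0, 0.2, 0.5]])
L = np.linalg.cholesky(B)  # B = L @ L.T
x = rng.multivariate_normal([1.0, 0.5, -0.5], 0.2 * np.eye(3), size=1000)

# Direct definition of the ellipse variant.
y_direct = x / np.sqrt(np.einsum("pi,ij,pj->p", x, B, x))[:, None]

# Same samples via whitening: z = L^T x, then y = L^{-T} (z / ||z||).
z = x @ L  # each row is (L^T x_i)^T
u = z / np.linalg.norm(z, axis=1, keepdims=True)
y_whitened = np.linalg.solve(L.T, u.T).T  # apply L^{-T} to each row

# Moments map linearly: E[y] = L^{-T} E[u].
mean_from_u = np.linalg.solve(L.T, u.mean(axis=0))
```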
- moments_empirical(n_samples=500000)
Compute moments of the distribution via sampling.
- Parameters:
n_samples (int) – Number of samples to draw for the empirical moments. Default is 500000.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
- pdf(y)
Compute the pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the pdf. Shape (n_points, n_dim).
- Returns:
PDF of the points y. Shape (n_points).
- Return type:
torch.Tensor
- sample(n_samples)
Sample from the distribution.
- Parameters:
n_samples (int) – Number of samples to draw.
- Returns:
Samples from the distribution. Shape (n_samples, n_dim).
- Return type:
torch.Tensor
- class ProjNormalEllipseConst(n_dim=None, mean_x=None, covariance_x=None, const=None, B=None)
Bases: ProjNormalEllipse
Projected normal distribution variant, describing the variable \(y=x/\sqrt{x^T B x + c}\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution, \(B\) is a symmetric positive definite matrix, and \(c\) is a positive scalar constant.
- Parameters:
n_dim (int) – Dimension of \(x\) (the embedding space). Optional: if mean_x and covariance_x are provided, it is not required.
mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
const (torch.Tensor, optional) – The denominator additive constant. Shape (1,). Default is 1.
B (torch.Tensor, optional) – SPD matrix defining the ellipse. Shape (n_dim, n_dim). Default is the identity matrix.
- Variables:
mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).
const (torch.Tensor) – Denominator additive constant. Shape (1,). Learnable parameter constrained to be positive.
B (torch.Tensor) – Quadratic form matrix of the denominator. Shape (n_dim, n_dim). Learnable parameter constrained to be SPD.
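Sampling combines both modifications. In this minimal NumPy sketch (hypothetical helper name), \(y^T B y = x^T B x/(x^T B x + c) < 1\), so samples fall strictly inside the ellipsoid defined by \(B\); with \(B = I\) the expression reduces to the ProjNormalConst case.

```python
import numpy as np


def sample_projnormal_ellipse_const(mean_x, covariance_x, const, B, n_samples, rng=None):
    """Draw y = x / sqrt(x^T B x + const) with x ~ N(mean_x, covariance_x). Hypothetical helper."""
    rng = np.random.default_rng() if rng is None else rng
    x = rng.multivariate_normal(mean_x, covariance_x, size=n_samples)
    quad = np.einsum("pi,ij,pj->p", x, B, x)  # x^T B x for each sample
    return x / np.sqrt(quad + const)[:, None]


B = np.array([[2.0, 0.3], [0.3, 0.5]])
samples = sample_projnormal_ellipse_const(
    np.array([1.0, 1.0]), 0.5 * np.eye(2), 1.0, B, 1000, rng=np.random.default_rng(7)
)
```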
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
Modules
- Class for the general projected normal distribution with denominator constant const.
- Constraints to keep the distribution parameters in a valid region.
- Class for the general projected normal distribution with denominator matrix B.
- Class for the general projected normal distribution.
- Class for the general projected normal distribution.