projnormal.classes.projected_normal
Class for the general projected normal distribution.
Classes
ProjNormal – Projected normal distribution, describing the variable \(y=x/\|x\|\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution.
- class ProjNormal(n_dim=None, mean_x=None, covariance_x=None)
Bases: Module
Projected normal distribution, describing the variable \(y=x/\|x\|\), where \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\) follows a multivariate normal distribution.
- Parameters:
n_dim (int) – Dimension of \(x\) (the embedding space). Not required if mean_x and covariance_x are provided.
mean_x (torch.Tensor, optional) – Mean of \(x\). Shape (n_dim). Default is random.
covariance_x (torch.Tensor, optional) – Covariance of \(x\). Shape (n_dim, n_dim). Default is the identity.
- Variables:
mean_x (torch.Tensor) – Mean of \(x\). Learnable parameter constrained to have unit norm. Shape (n_dim).
covariance_x (torch.Tensor) – Covariance of \(x\). Learnable parameter constrained to be SPD. Shape (n_dim, n_dim).
- log_pdf(y)
Compute the log pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the log pdf. Shape (n_points, n_dim).
- Returns:
Log-PDF of y. Shape (n_points).
- Return type:
torch.Tensor
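For intuition about what log_pdf evaluates, the density has a simple closed form in one special case: n_dim = 2 with identity covariance. The sketch below covers only that case and is not the library's implementation, which handles a general \(\Sigma_x\). The names `log_pdf_circular` and `total_mass` are hypothetical, and only the standard library is used.

```python
import math


def log_pdf_circular(theta, mean_x):
    """Log density of y = x/||x|| at angle theta, for x ~ N(mean_x, I) in 2D.

    Closed form for this special case:
        p(theta) = exp(-||mu||^2 / 2) / (2 pi) * (1 + d * Phi(d) / phi(d)),
    where d = mu . (cos theta, sin theta), and phi, Phi are the standard
    normal pdf and cdf.
    """
    mu1, mu2 = mean_x
    d = mu1 * math.cos(theta) + mu2 * math.sin(theta)
    phi = math.exp(-0.5 * d * d) / math.sqrt(2.0 * math.pi)  # standard normal pdf
    Phi = 0.5 * (1.0 + math.erf(d / math.sqrt(2.0)))         # standard normal cdf
    bracket = 1.0 + d * Phi / phi
    return -0.5 * (mu1 ** 2 + mu2 ** 2) - math.log(2.0 * math.pi) + math.log(bracket)


def total_mass(mean_x, n_grid=2000):
    """Numerically integrate the density over the circle (should be close to 1)."""
    step = 2.0 * math.pi / n_grid
    return step * sum(math.exp(log_pdf_circular(i * step, mean_x)) for i in range(n_grid))
```

With a zero mean the expression reduces to the uniform density \(1/(2\pi)\) on the circle, which is a quick sanity check on the formula.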
- max_likelihood(y, max_epochs=300, lr=0.1, optimizer='NAdam', show_progress=True, return_loss=False, n_cycles=1, cycle_gamma=0.5, **kwargs)
Fit the distribution parameters via maximum likelihood, by iteratively minimizing the negative log-likelihood of the data y.
- Parameters:
y (torch.Tensor) – Observed data. Shape (n_samples, n_dim).
max_epochs (int) – Maximum number of training epochs. Default is 300.
lr (float) – Learning rate for the optimizer. Default is 0.1.
optimizer (str) – Optimizer to use. Options are ‘LBFGS’ and ‘NAdam’. Default is ‘NAdam’.
show_progress (bool) – If True, show a progress bar during training. Default is True.
return_loss (bool) – If True, return the loss after training. Default is False.
n_cycles (int) – For the NAdam optimizer, the number of times to run the optimization loop. Default is 1.
cycle_gamma (float) – Factor by which lr is reduced after each optimization loop repetition. Default is 0.5.
**kwargs – Additional keyword arguments passed to the fitting function. For the NAdam optimizer, the parameters gamma and step_size can be passed to control the learning rate schedule.
- Returns:
Dictionary containing the loss and training time.
- Return type:
dict
- moment_init(data_moments)
Initialize the distribution parameters using the observed moments as the initial guess (after making the mean unit norm).
- Parameters:
data_moments (dict) – Dictionary with keys mean and covariance, containing the observed mean and covariance of the data respectively.
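A minimal sketch of the initialization described above, assuming the observed mean is rescaled to unit norm and the observed covariance is used unchanged as the starting guess. The function `moment_init_sketch` is hypothetical and uses plain lists rather than tensors; the library's actual handling of the covariance may differ.

```python
import math


def moment_init_sketch(data_moments):
    """Return an initial guess (mean_x, covariance_x) from observed moments."""
    mean = data_moments["mean"]
    norm = math.sqrt(sum(m * m for m in mean))
    # Rescale the observed mean so that the initial mean_x has unit norm.
    mean_x = [m / norm for m in mean]
    # Assumption: the observed covariance is copied as-is.
    covariance_x = [list(row) for row in data_moments["covariance"]]
    return mean_x, covariance_x
```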
- moment_match(data_moments, max_epochs=200, lr=0.1, optimizer='NAdam', loss_fun=None, show_progress=True, return_loss=False, n_cycles=3, cycle_gamma=0.5, **kwargs)
Fit the distribution parameters via moment matching.
- Parameters:
data_moments (dict) – Dictionary containing the observed moments, with keys mean and covariance.
max_epochs (int) – Maximum number of training epochs. Default is 200.
lr (float) – Learning rate for the optimizer. Default is 0.1.
optimizer (str) – Optimizer to use. Options are ‘LBFGS’ and ‘NAdam’. Default is ‘NAdam’.
loss_fun (callable) – Loss function to use for moment matching. Default is squared error.
show_progress (bool) – If True, show a progress bar during training. Default is True.
return_loss (bool) – If True, return the loss after training. Default is False.
n_cycles (int) – For the NAdam optimizer, the number of times to run the optimization loop. Default is 3.
cycle_gamma (float) – Factor by which lr is reduced after each optimization loop repetition. Default is 0.5.
**kwargs – Additional keyword arguments passed to the fitting function. For the NAdam optimizer, the parameters gamma and step_size can be passed to control the learning rate schedule.
- Returns:
Dictionary containing the loss and training time.
- Return type:
dict
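The interaction of lr, n_cycles and cycle_gamma can be illustrated with a short sketch. It assumes that each repetition of the optimization loop starts from the previous learning rate multiplied by cycle_gamma; the helper `cycle_learning_rates` is hypothetical and only mirrors the defaults in the signature above.

```python
def cycle_learning_rates(lr=0.1, n_cycles=3, cycle_gamma=0.5):
    """Starting learning rate of each optimization cycle.

    Assumption: lr is multiplied by cycle_gamma after every cycle.
    """
    return [lr * cycle_gamma ** k for k in range(n_cycles)]


print(cycle_learning_rates())  # [0.1, 0.05, 0.025]
```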
- moments()
Compute moments of the distribution via Taylor approximation.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
- moments_empirical(n_samples=200000)
Compute moments of the distribution via sampling.
- Parameters:
n_samples (int) – Number of samples to draw for empirical moments. Default is 200000.
- Returns:
Dictionary with keys mean, covariance and second_moment, containing the corresponding moments of the distribution.
- Return type:
dict
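As a rough illustration of what sampling-based moments involve, the sketch below estimates the three quantities from a list of unit vectors. It assumes that second_moment denotes the non-centered moment \(E[y y^T]\) and that the covariance is normalized by the number of samples; `empirical_moments` and `draw_unit_vectors` are hypothetical helpers, and the sampler is restricted to identity covariance for brevity.

```python
import math
import random


def draw_unit_vectors(mean_x, n_samples, seed=0):
    """Draw x ~ N(mean_x, I) and project each draw onto the unit sphere."""
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        x = [rng.gauss(m, 1.0) for m in mean_x]
        norm = math.sqrt(sum(v * v for v in x))
        samples.append([v / norm for v in x])
    return samples


def empirical_moments(samples):
    """Estimate mean, covariance and non-centered second moment from samples."""
    n = len(samples)
    dim = len(samples[0])
    mean = [sum(y[i] for y in samples) / n for i in range(dim)]
    second = [[sum(y[i] * y[j] for y in samples) / n for j in range(dim)]
              for i in range(dim)]
    # Covariance = second moment minus the outer product of the mean.
    cov = [[second[i][j] - mean[i] * mean[j] for j in range(dim)]
           for i in range(dim)]
    return {"mean": mean, "covariance": cov, "second_moment": second}
```

Because every sample has unit norm, the trace of the second moment is 1 up to floating-point error, which gives a cheap consistency check on any estimate.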
- pdf(y)
Compute the pdf of points y.
- Parameters:
y (torch.Tensor) – Points at which to evaluate the pdf. Shape (n_points, n_dim).
- Returns:
PDF of the points y. Shape (n_points).
- Return type:
torch.Tensor
- sample(n_samples)
Sample from the distribution.
- Parameters:
n_samples (int) – Number of samples to draw.
- Returns:
Samples from the distribution. Shape (n_samples, n_dim).
- Return type:
torch.Tensor
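Following the definition \(y=x/\|x\|\) with \(x \sim \mathcal{N}(\mu_x, \Sigma_x)\), sampling amounts to drawing a Gaussian vector and dividing it by its norm. The sketch below shows this with the standard library only; it is an illustration of the definition rather than the library's code, and `cholesky` and `sample_projected_normal` are hypothetical names operating on plain lists.

```python
import math
import random


def cholesky(a):
    """Lower-triangular L with L L^T = a, for a symmetric positive definite a."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                low[i][j] = math.sqrt(a[i][i] - s)
            else:
                low[i][j] = (a[i][j] - s) / low[j][j]
    return low


def sample_projected_normal(mean_x, covariance_x, n_samples, seed=0):
    """Draw n_samples points y = x/||x|| with x ~ N(mean_x, covariance_x)."""
    rng = random.Random(seed)
    low = cholesky(covariance_x)
    dim = len(mean_x)
    samples = []
    for _ in range(n_samples):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        # x = mean_x + L z has mean mean_x and covariance L L^T.
        x = [mean_x[i] + sum(low[i][k] * z[k] for k in range(i + 1))
             for i in range(dim)]
        norm = math.sqrt(sum(v * v for v in x))
        samples.append([v / norm for v in x])
    return samples
```

Each returned row lies on the unit sphere, matching the documented output shape (n_samples, n_dim).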