Structured Gaussians
This page documents the LRDMvNormal type, which represents low-rank plus diagonal multivariate normal distributions.
Overview
LRDMvNormal represents a multivariate normal distribution with a covariance matrix of the form $\Sigma = FF' + D$, where:
- F is a low-rank factor matrix of size $m \times r$ where $r \ll m$
- D is a diagonal matrix, stored as a vector of its diagonal entries
- For efficiency, the full covariance matrix is never explicitly formed
This structure is particularly useful in the following circumstances:
- High-dimensional data where $m \gg n$ (number of features much larger than number of samples)
- Low-rank structure in the data
- Memory constraints preventing full covariance storage
- Efficient sampling requirements
Constructor
StructuredGaussianMixtures.LRDMvNormal — Type
LRDMvNormal
Low-rank plus diagonal multivariate normal distribution. The covariance matrix is represented as Σ = FF' + D, where F is a low-rank factor matrix and D is a diagonal matrix.
Fields
- μ: Mean vector
- F: Low-rank factor matrix
- D: Diagonal vector
- rank: Rank of the low-rank component
Notes
- The covariance matrix is never explicitly formed
- All operations use the low-rank plus diagonal structure for efficiency
Usage
using StructuredGaussianMixtures
# Create a 10-dimensional distribution with rank-3 structure
μ = randn(10) # Mean vector
F = randn(10, 3) # Low-rank factor (10×3)
D = ones(10) # Diagonal vector
dist = LRDMvNormal(μ, F, D)
Distribution Interface
Basic Properties
Base.length — Method
length(d::LRDMvNormal)
Return the dimension of the distribution.
Base.size — Method
size(d::LRDMvNormal)
Return the size of the distribution as a tuple (dimension,).
Statistics.mean — Method
mean(d::LRDMvNormal)
Return the mean vector of the distribution.
Statistics.cov — Method
cov(d::LRDMvNormal)
Return the full covariance matrix FF' + D.
Low-Rank Structure Access
StructuredGaussianMixtures.rank — Method
rank(d::LRDMvNormal)
Return the rank of the low-rank component.
StructuredGaussianMixtures.low_rank_factor — Method
low_rank_factor(d::LRDMvNormal)
Return the low-rank factor matrix F.
StructuredGaussianMixtures.diagonal — Method
diagonal(d::LRDMvNormal)
Return the diagonal vector D.
Probability and Sampling
Log Probability Density
Distributions.logpdf — Method
logpdf(d::LRDMvNormal, x::AbstractVector)
Compute the log probability density function at x. Uses the matrix inversion lemma for efficient computation.
Arguments
- d: The LRDMvNormal distribution
- x: The point at which to evaluate the log PDF
Returns
- The log probability density at x
Notes
- Uses the matrix inversion lemma: (FF' + D)^(-1) = D^(-1) - D^(-1)F(I + F'D^(-1)F)^(-1)F'D^(-1)
- Computes the determinant efficiently: det(FF' + D) = det(D) det(I + F'D^(-1)F)
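The two identities above can be combined into a log-density evaluation that only ever factorizes an r × r matrix. The following is a minimal sketch using only the LinearAlgebra standard library; `lrd_logpdf_sketch` is a hypothetical helper written for illustration and is not the package's implementation, which may differ in detail.

```julia
using LinearAlgebra

# Hypothetical helper (not part of StructuredGaussianMixtures).
# μ: mean (length m), F: m×r factor, D: vector of diagonal entries (length m).
function lrd_logpdf_sketch(μ::AbstractVector, F::AbstractMatrix,
                           D::AbstractVector, x::AbstractVector)
    m = length(μ)
    z = x .- μ
    Dinv_z = z ./ D                  # D⁻¹z, O(m)
    Dinv_F = F ./ D                  # D⁻¹F: row i of F divided by D[i], O(mr)
    # r×r matrix I + F'D⁻¹F, factorized once and reused
    C = cholesky(Symmetric(I + F' * Dinv_F))
    w = F' * Dinv_z                  # F'D⁻¹z, an r-vector
    # Quadratic form z'Σ⁻¹z via the matrix inversion lemma
    quad = dot(z, Dinv_z) - dot(w, C \ w)
    # log det Σ = log det D + log det(I + F'D⁻¹F)
    logdetΣ = sum(log, D) + logdet(C)
    return -0.5 * (m * log(2π) + logdetΣ + quad)
end
```

For small m, the result can be compared against a dense evaluation that forms `F * F' + Diagonal(D)` explicitly; the two should agree up to floating-point error.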
Random Sampling
Distributions._rand! — Method
_rand!(rng::AbstractRNG, d::LRDMvNormal, x::VecOrMat)
Generate random samples in-place from the distribution.
Arguments
- rng: Random number generator
- d: The LRDMvNormal distribution
- x: Vector or matrix to fill with random samples
Returns
- The filled vector/matrix x
Notes
- Uses the decomposition: X = μ + FZ₁ + sqrt(D)Z₂ where Z₁, Z₂ are standard normal
- For matrices, each column is a sample
Distributions._rand! — Method
_rand!(rng::AbstractRNG, d::LRDMvNormal, x::AbstractVector)
Generate a random sample in-place from the distribution.
Arguments
- rng: Random number generator
- d: The LRDMvNormal distribution
- x: Vector to fill with random sample
Returns
- The filled vector x
Notes
- Uses the decomposition: X = μ + FZ₁ + sqrt(D)Z₂ where Z₁, Z₂ are standard normal
- Handles AbstractVector types that don't support randn!
Simple Examples
Creating and Using LRDMvNormal
using StructuredGaussianMixtures
# Create a low-rank distribution
n_features = 100
r = 5 # Named r, not rank, so the exported rank function is not shadowed
μ = randn(n_features)
F = randn(n_features, r)
D = ones(n_features) * 0.1
dist = LRDMvNormal(μ, F, D)
# Basic properties
println("Dimension: ", length(dist))
println("Rank: ", rank(dist))
println("Mean: ", mean(dist)[1:5]) # First 5 elements
# Generate samples
samples = rand(dist, 1000)
println("Sample shape: ", size(samples))
# Compute log probability
log_prob = logpdf(dist, samples[:, 1])
println("Log probability: ", log_prob)
Efficient Operations
# The covariance matrix is never explicitly formed
# This is efficient even for high dimensions
n_features = 1000
r = 10 # Named r, not rank, so the exported rank function is not shadowed
μ = randn(n_features)
F = randn(n_features, r)
D = ones(n_features) * 0.1
dist = LRDMvNormal(μ, F, D)
# This is efficient: no O(m³) operations, where m = n_features
log_prob = logpdf(dist, randn(n_features))
# Sampling is also efficient
samples = rand(dist, 100)
Accessing Components
# Get the low-rank structure
F_matrix = low_rank_factor(dist)
D_vector = diagonal(dist)
rank_val = rank(dist)
println("F shape: ", size(F_matrix))
println("D length: ", length(D_vector))
println("Rank: ", rank_val)
# Full covariance (only for small dimensions!)
full_cov = cov(dist)
println("Full covariance shape: ", size(full_cov))
Mathematical Background
Efficient Log-Likelihood Computation
The log probability density is computed efficiently using block elimination and the matrix inversion lemma to compute the quadratic form, and the standard determinant formula for low-rank plus diagonal matrices. The approach avoids forming the full covariance matrix and reduces computational complexity from O(m³) to O(mr² + r³) where r ≪ m.
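As a sketch of where the stated cost comes from, the standard multivariate normal log-density can be written with the shorthand $z = x - \mu$ and $C = I + F'D^{-1}F$ (both symbols are introduced here for illustration only):

\[\log p(x) = -\tfrac{1}{2}\left[m\log(2\pi) + \log\det D + \log\det C + z'D^{-1}z - (F'D^{-1}z)'\,C^{-1}\,(F'D^{-1}z)\right]\]

Forming $C$ costs $O(mr^2)$, factorizing the $r \times r$ matrix $C$ costs $O(r^3)$, and every remaining term costs $O(mr)$ or less, which gives the $O(mr^2 + r^3)$ total.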
Sampling
Samples are generated using the decomposition:
\[X = \mu + FZ_1 + \sqrt{D}Z_2\]
where $Z_1$ and $Z_2$ are independent standard normal random vectors.
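Because $Z_1$ and $Z_2$ are independent with identity covariance, the covariance of $X$ is $FF' + D$, so the decomposition yields the right distribution without ever forming $\Sigma$. A minimal sketch of this idea follows; `lrd_sample_sketch` is a hypothetical helper for illustration, not the package's `_rand!` method, and it allocates its output rather than filling a buffer in place.

```julia
using Random

# Hypothetical helper (not part of StructuredGaussianMixtures).
# Returns an m×n matrix whose columns are independent samples.
function lrd_sample_sketch(rng::AbstractRNG, μ::AbstractVector,
                           F::AbstractMatrix, D::AbstractVector, n::Integer)
    m, r = size(F)
    Z1 = randn(rng, r, n)    # latent low-rank noise, one column per sample
    Z2 = randn(rng, m, n)    # independent per-coordinate noise
    # X = μ + F*Z1 + sqrt(D) .* Z2, broadcasting μ and sqrt.(D) over columns
    return μ .+ F * Z1 .+ sqrt.(D) .* Z2
end

rng = MersenneTwister(42)
X = lrd_sample_sketch(rng, zeros(100), randn(rng, 100, 5), fill(0.1, 100), 1000)
println(size(X))
```

Each sample needs one m × r matrix-vector product plus elementwise work, which is consistent with the $O(mr)$ per-sample cost.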
Performance Characteristics
Computational Complexity
Operation | Complexity | Notes |
---|---|---|
Log-likelihood | O(mr² + r³) | Uses block elimination with Schur complement |
Sampling | O(mr) | Efficient decomposition |
Memory | O(mr) | Stores F and D, not full covariance |
Memory Usage
For a distribution with $m$ features and rank $r$:
- Storage: $O(mr + m) = O(mr)$ for F and D
- Traditional: $O(m^2)$ for the full covariance matrix
- Savings: a factor of $O(m^2 / (mr)) = O(m/r)$ when $r \ll m$
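For a concrete sense of scale, take $m = 1000$ and $r = 10$, the dimensions used in the Efficient Operations example, and count stored numbers for the covariance only (ignoring the mean vector and any container overhead):

\[mr + m = 11{,}000 \qquad \text{versus} \qquad m^2 = 1{,}000{,}000\]

That is a reduction of roughly $91\times$. Assuming 8-byte Float64 entries, this corresponds to about 88 KB instead of 8 MB.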
Integration with GMMs
The LRDMvNormal distribution is used internally by the PCAEM and FactorEM methods:
# PCAEM creates LRDMvNormal components
gmm = fit(PCAEM(3, 5), data)
for comp in components(gmm)
println("Component rank: ", rank(comp))
end
# FactorEM also creates LRDMvNormal components
gmm = fit(FactorEM(3, 5), data)
for comp in components(gmm)
println("Component rank: ", rank(comp))
end