Structured Gaussians

This page documents the LRDMvNormal class, which represents low-rank plus diagonal multivariate normal distributions.

Overview

The LRDMvNormal distribution represents a multivariate normal distribution with covariance matrix of the form $\Sigma = FF' + D$, where:

  • F is a low-rank factor matrix of size $m \times r$ where $r \ll m$
  • D is a diagonal matrix with positive entries, stored as a vector of length $m$
  • For efficiency, the full covariance matrix is not formed unless it is requested explicitly with cov
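To illustrate what working with the structure directly looks like, the sketch below multiplies $\Sigma$ by a vector without building the $m \times m$ matrix. It uses only the LinearAlgebra standard library; `cov_times` is a hypothetical helper name chosen for this example, not a function of StructuredGaussianMixtures.

```julia
using LinearAlgebra

# Hypothetical helper (not part of the package): compute (F*F' + Diagonal(D)) * v
# in O(m*r) time without forming the m×m covariance.
cov_times(F::AbstractMatrix, D::AbstractVector, v::AbstractVector) =
    F * (F' * v) .+ D .* v

m, r = 200, 4
F = randn(m, r)
D = fill(0.5, m)
v = randn(m)

fast = cov_times(F, D, v)
dense = (F * F' + Diagonal(D)) * v   # reference only; this forms the full matrix
println("max abs difference: ", maximum(abs.(fast .- dense)))
```

The parenthesization `F * (F' * v)` is what keeps the cost at $O(mr)$; evaluating `(F * F') * v` would build the dense matrix first.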

This structure is particularly useful in the following circumstances:

  • High-dimensional data where $m \gg n$ (number of features much larger than number of samples)
  • Low-rank structure in the data
  • Memory constraints preventing full covariance storage
  • Efficient sampling requirements

Constructor

StructuredGaussianMixtures.LRDMvNormal (Type)
LRDMvNormal

Low-rank plus diagonal multivariate normal distribution. The covariance matrix is represented as Σ = FF' + D, where F is a low-rank factor matrix and D is a diagonal matrix.

Fields

  • μ: Mean vector
  • F: Low-rank factor matrix
  • D: Diagonal entries of the diagonal component, stored as a vector
  • rank: Rank of the low-rank component

Notes

  • The covariance matrix is not formed unless cov is called
  • All operations use the low-rank plus diagonal structure for efficiency

Usage

using StructuredGaussianMixtures

# Create a 10-dimensional distribution with rank-3 structure
μ = randn(10)           # Mean vector
F = randn(10, 3)        # Low-rank factor (10×3)
D = ones(10)            # Diagonal vector
dist = LRDMvNormal(μ, F, D)

Distribution Interface

Basic Properties

Base.length (Method)
length(d::LRDMvNormal)

Return the dimension of the distribution.

Base.size (Method)
size(d::LRDMvNormal)

Return the size of the distribution as a tuple (dimension,).

Statistics.mean (Method)
mean(d::LRDMvNormal)

Return the mean vector of the distribution.

Statistics.cov (Method)
cov(d::LRDMvNormal)

Return the full covariance matrix FF' + D. This forms a dense m×m matrix, so it is only practical for small dimensions.

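Low-Rank Structure Access

The accessors low_rank_factor(d), diagonal(d), and rank(d) return the factor matrix F, the diagonal vector D, and the rank of the low-rank component; the Accessing Components example shows them in use.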

Probability and Sampling

Log Probability Density

Distributions.logpdf (Method)
logpdf(d::LRDMvNormal, x::AbstractVector)

Compute the log probability density function at x. Uses the matrix inversion lemma for efficient computation.

Arguments

  • d: The LRDMvNormal distribution
  • x: The point at which to evaluate the log PDF

Returns

  • The log probability density at x

Notes

  • Uses the matrix inversion lemma: (FF' + D)^(-1) = D^(-1) - D^(-1) F (I + F' D^(-1) F)^(-1) F' D^(-1)
  • Computes the determinant efficiently: det(FF' + D) = det(D) det(I + F' D^(-1) F)
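The two identities above can be combined as in the following sketch. It is a standalone reference written for this page, not the package's implementation; the function name `lrd_logpdf_sketch` is hypothetical, and the result is checked against a dense computation.

```julia
using LinearAlgebra

# Hypothetical reference (not the package's code): log density of
# N(μ, F*F' + Diagonal(D)) at x, using only r×r factorizations.
function lrd_logpdf_sketch(μ, F, D, x)
    m, r = size(F)
    z = x .- μ
    Dinv_z = z ./ D                            # D^(-1) z
    Dinv_F = F ./ D                            # D^(-1) F (row i scaled by 1/D[i])
    C = cholesky(Symmetric(I + F' * Dinv_F))   # r×r matrix I + F' D^(-1) F
    w = F' * Dinv_z                            # F' D^(-1) z
    quad = dot(z, Dinv_z) - dot(w, C \ w)      # z' Σ^(-1) z via the inversion lemma
    logdetΣ = sum(log, D) + logdet(C)          # log det(D) + log det(I + F' D^(-1) F)
    return -0.5 * (m * log(2π) + logdetΣ + quad)
end

# Compare with a dense evaluation on a small problem.
m, r = 50, 3
μ = randn(m); F = randn(m, r); D = rand(m) .+ 0.1; x = randn(m)
fast = lrd_logpdf_sketch(μ, F, D, x)

Ch = cholesky(Symmetric(F * F' + Diagonal(D)))
z = x .- μ
dense = -0.5 * (m * log(2π) + logdet(Ch) + dot(z, Ch \ z))
println("structured: ", fast, "  dense: ", dense)
```

The sketch assumes every entry of D is strictly positive, since it divides by D and takes its logarithm.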

Random Sampling

Distributions._rand! (Method)
_rand!(rng::AbstractRNG, d::LRDMvNormal, x::VecOrMat)

Generate random samples in-place from the distribution.

Arguments

  • rng: Random number generator
  • d: The LRDMvNormal distribution
  • x: Vector or matrix to fill with random samples

Returns

  • The filled vector/matrix x

Notes

  • Uses the decomposition X = μ + F Z₁ + sqrt(D) Z₂, where Z₁ (length r) and Z₂ (length m) are independent standard normal vectors and sqrt(D) is taken elementwise
  • For matrices, each column is a sample

Distributions._rand! (Method)
_rand!(rng::AbstractRNG, d::LRDMvNormal, x::AbstractVector)

Generate a random sample in-place from the distribution.

Arguments

  • rng: Random number generator
  • d: The LRDMvNormal distribution
  • x: Vector to fill with random sample

Returns

  • The filled vector x

Notes

  • Uses the decomposition X = μ + F Z₁ + sqrt(D) Z₂, where Z₁ (length r) and Z₂ (length m) are independent standard normal vectors and sqrt(D) is taken elementwise
  • Handles AbstractVector types that don't support randn!

Simple Examples

Creating and Using LRDMvNormal

using StructuredGaussianMixtures

# Create a low-rank distribution
n_features = 100
r = 5                        # named r so the rank() function is not shadowed
μ = randn(n_features)
F = randn(n_features, r)
D = fill(0.1, n_features)

dist = LRDMvNormal(μ, F, D)

# Basic properties
println("Dimension: ", length(dist))
println("Rank: ", rank(dist))
println("Mean: ", mean(dist)[1:5])  # First 5 elements

# Generate samples
samples = rand(dist, 1000)
println("Sample shape: ", size(samples))

# Compute log probability
log_prob = logpdf(dist, samples[:, 1])
println("Log probability: ", log_prob)

Efficient Operations

# The covariance matrix is never explicitly formed
# This is efficient even for high dimensions
n_features = 1000
r = 10                       # named r so the rank() function is not shadowed

μ = randn(n_features)
F = randn(n_features, r)
D = fill(0.1, n_features)

dist = LRDMvNormal(μ, F, D)

# This is efficient: no O(m³) operations, where m = n_features
log_prob = logpdf(dist, randn(n_features))

# Sampling is also efficient
samples = rand(dist, 100)

Accessing Components

# Get the low-rank structure
F_matrix = low_rank_factor(dist)
D_vector = diagonal(dist)
rank_val = rank(dist)

println("F shape: ", size(F_matrix))
println("D length: ", length(D_vector))
println("Rank: ", rank_val)

# Full covariance (only for small dimensions!)
full_cov = cov(dist)
println("Full covariance shape: ", size(full_cov))
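When only the marginal variances are needed, the diagonal of $FF' + D$ can be obtained without the full covariance. The sketch below shows the idea on plain arrays; `marginal_variances` is a hypothetical helper for this example, not a function of the package.

```julia
using LinearAlgebra

# Hypothetical helper (not part of the package): diag(F*F' + Diagonal(D))
# in O(m*r) time and O(m) memory.
marginal_variances(F::AbstractMatrix, D::AbstractVector) =
    vec(sum(abs2, F; dims=2)) .+ D

m, r = 500, 5
F = randn(m, r)
D = fill(0.1, m)

vars = marginal_variances(F, D)
println("first variances: ", vars[1:3])
```

With the accessors shown above, the same computation would presumably be applied to `low_rank_factor(dist)` and `diagonal(dist)`.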

Mathematical Background

Efficient Log-Likelihood Computation

The log probability density is computed from two identities for low-rank plus diagonal matrices: the matrix inversion lemma (equivalently, block elimination with a Schur complement) gives the quadratic form, and the matrix determinant lemma gives the log-determinant. Neither step forms the full covariance matrix, and the cost drops from O(m³) for a dense factorization to O(mr² + r³), where r ≪ m.
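Written out, with $z = x - \mu$ and the $r \times r$ matrix $C = I + F' D^{-1} F$ (the symbols $z$ and $C$ are introduced here for convenience and do not appear in the package documentation), the identities listed in the logpdf notes give:

\[
\begin{aligned}
z' \Sigma^{-1} z &= z' D^{-1} z - (F' D^{-1} z)' \, C^{-1} \, (F' D^{-1} z), \\
\log\det \Sigma &= \sum_{i=1}^{m} \log D_{ii} + \log\det C, \\
\log p(x) &= -\tfrac{1}{2}\left( m \log 2\pi + \log\det \Sigma + z' \Sigma^{-1} z \right).
\end{aligned}
\]

Forming $C$ costs $O(mr^2)$ and factorizing it costs $O(r^3)$, which accounts for the stated complexity.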

Sampling

Samples are generated using the decomposition:

\[X = \mu + FZ_1 + \sqrt{D}Z_2\]

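where $Z_1 \in \mathbb{R}^r$ and $Z_2 \in \mathbb{R}^m$ are independent standard normal random vectors and $\sqrt{D}$ is the elementwise square root of the diagonal. Independence gives $\operatorname{Cov}(X) = FF' + D$.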
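A minimal sketch of this decomposition on plain arrays follows, with an empirical check that the sample covariance approaches $FF' + D$. The function name `lrd_sample_sketch` is hypothetical and the code is an illustration, not the package's `_rand!` method.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical sketch (not the package's code): draw n samples as columns
# via X = μ + F*Z₁ + sqrt.(D) .* Z₂.
function lrd_sample_sketch(rng::AbstractRNG, μ, F, D, n::Integer)
    m, r = size(F)
    Z1 = randn(rng, r, n)                 # latent factor noise, one column per sample
    Z2 = randn(rng, m, n)                 # independent per-coordinate noise
    return μ .+ F * Z1 .+ sqrt.(D) .* Z2
end

rng = MersenneTwister(0)
m, r = 6, 2
μ = randn(rng, m); F = randn(rng, m, r); D = fill(0.3, m)
X = lrd_sample_sketch(rng, μ, F, D, 200_000)

Σ = F * F' + Diagonal(D)                  # small m, so the dense matrix is fine here
println("max covariance error: ", maximum(abs.(cov(X; dims=2) .- Σ)))
```

Each sample needs only $r + m$ standard normal draws and one $m \times r$ matrix-vector product, so no factorization of $\Sigma$ is required.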

Performance Characteristics

Computational Complexity

| Operation | Complexity | Notes |
|---|---|---|
| Log-likelihood | O(mr² + r³) | Uses block elimination with Schur complement |
| Sampling | O(mr) | Efficient decomposition |
| Memory | O(mr) | Stores F and D, not full covariance |

Memory Usage

For a distribution with $m$ features and rank $r$:

  • Storage: $O(mr + m) = O(mr)$ for F and D
  • Traditional: $O(m^2)$ for a full covariance matrix
  • Savings: a factor of roughly $m/r$ when $r \ll m$
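As a worked example of these counts, the snippet below compares the number of stored entries for $m = 1000$ and $r = 10$, the sizes used in the Efficient Operations example. Float64 storage is an assumption made here for the byte counts.

```julia
# Worked example: storage for m = 1000 features and rank r = 10.
m, r = 1000, 10
structured = m * r + m        # entries in F plus entries in D
dense = m^2                   # entries in a full covariance matrix
bytes = sizeof(Float64)       # assumes Float64 elements

println("structured: ", structured, " entries (", structured * bytes, " bytes)")
println("dense:      ", dense, " entries (", dense * bytes, " bytes)")
println("ratio:      ", round(dense / structured; digits=1))
```

Here the structured form holds 11,000 entries against 1,000,000 for the dense matrix, a ratio of about 91, slightly below $m/r = 100$ because of the extra $m$ entries in D.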

Integration with GMMs

The LRDMvNormal distribution is used internally by PCAEM and FactorEM methods:

# `data` is the training data, assumed to be defined elsewhere
# PCAEM creates LRDMvNormal components
gmm = fit(PCAEM(3, 5), data)
for comp in components(gmm)
    println("Component rank: ", rank(comp))
end

# FactorEM also creates LRDMvNormal components
gmm = fit(FactorEM(3, 5), data)
for comp in components(gmm)
    println("Component rank: ", rank(comp))
end