# Prediction
This page describes how to perform conditional prediction with fitted Gaussian Mixture Models using StructuredGaussianMixtures.jl.
## Overview

The `predict` function computes posterior distributions over unobserved dimensions given observed values. This is useful for:
- **Missing data imputation**: fill in missing values based on observed data
- **Conditional sampling**: generate samples from the posterior distribution
- **Uncertainty quantification**: assess uncertainty in predictions
## Function Signature

```julia
predict(gmm, x; input_indices=1:length(x), output_indices=length(x)+1:length(gmm))
```
### Parameters

- `gmm`: a fitted `MixtureModel`
- `x`: observed values
- `input_indices`: indices of observed dimensions (default: `1:length(x)`)
- `output_indices`: indices of dimensions to predict (default: `length(x)+1:length(gmm)`)
## Basic Usage

```julia
using StructuredGaussianMixtures

# Fit a GMM
data = randn(2, 1000)
gmm = fit(EM(3), data)

# Make predictions
query_point = [0.5]                    # observed value for the first dimension
posterior = predict(gmm, query_point)  # posterior over the second dimension
```
## Predict Functions

### MvNormal Prediction

`StructuredGaussianMixtures.predict` — Method

```julia
predict(dist::MvNormal, x::AbstractVector, input_indices::Union{Vector{Int},AbstractRange}, output_indices::Union{Vector{Int},AbstractRange})
```

Compute the conditional distribution of the output indices given the input indices using the Schur complement. Returns a new `MvNormal` distribution representing the conditional distribution.
**Arguments**

- `dist`: the multivariate normal distribution
- `x`: the observed values for the input indices
- `input_indices`: the indices of the observed variables
- `output_indices`: the indices of the variables to predict

**Returns**

- A new `MvNormal` distribution representing the conditional distribution
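The Schur-complement conditioning described above can be sketched directly with plain Distributions.jl types. The helper below is illustrative only and is independent of the package's actual implementation:

```julia
using LinearAlgebra
using Distributions

# Illustrative sketch of conditional-Gaussian prediction via the Schur complement.
# (Hypothetical helper; not the package's internal code.)
function conditional_mvnormal(dist::MvNormal, x, in_idx, out_idx)
    μ, Σ = mean(dist), cov(dist)
    K  = Σ[out_idx, in_idx] / Σ[in_idx, in_idx]   # Σ_oi Σ_ii⁻¹ via right division
    μc = μ[out_idx] + K * (x - μ[in_idx])          # conditional mean
    Σc = Σ[out_idx, out_idx] - K * Σ[out_idx, in_idx]'  # Schur complement
    return MvNormal(μc, Matrix(Symmetric(Σc)))
end

# 2D example: condition the second coordinate on the first
joint = MvNormal([0.0, 0.0], [1.0 0.8; 0.8 1.0])
post = conditional_mvnormal(joint, [0.5], [1], [2])
mean(post)  # [0.4], since μ₂|₁ = 0.8 · 0.5
var(post)   # [0.36], since 1 - 0.8² = 0.36
```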
### LRDMvNormal Prediction

`StructuredGaussianMixtures.predict` — Method

```julia
predict(dist::LRDMvNormal, x::AbstractVector, input_indices::Union{Vector{Int},AbstractRange}, output_indices::Union{Vector{Int},AbstractRange})
```

Compute the conditional distribution of the output indices given the input indices using the Schur complement. Returns a new `LRDMvNormal` distribution representing the conditional distribution. This implementation is efficient for the low-rank plus diagonal covariance structure.
**Arguments**

- `dist`: the low-rank plus diagonal multivariate normal distribution
- `x`: the observed values for the input indices
- `input_indices`: the indices of the observed variables
- `output_indices`: the indices of the variables to predict

**Returns**

- A new `LRDMvNormal` distribution representing the conditional distribution
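The low-rank plus diagonal structure is what makes this efficient: with covariance $\Sigma = F F^\top + D$, linear solves against $\Sigma$ can use the Woodbury identity instead of a dense factorization, costing $O(nr^2)$ for rank $r \ll n$ rather than $O(n^3)$. A minimal sketch of that identity (illustrative; the package's internals may organize this differently):

```julia
using LinearAlgebra, Random
Random.seed!(42)

# Woodbury identity: (D + F Fᵀ)⁻¹ = D⁻¹ - D⁻¹ F (I + Fᵀ D⁻¹ F)⁻¹ Fᵀ D⁻¹
# Solves (D + F Fᵀ) z = b without ever forming the dense n×n covariance.
function woodbury_solve(d::Vector, F::Matrix, b::Vector)
    Dinv_b = b ./ d
    Dinv_F = F ./ d
    S = I + F' * Dinv_F               # r × r "capacitance" matrix
    return Dinv_b - Dinv_F * (S \ (F' * Dinv_b))
end

n, r = 100, 3
d = rand(n) .+ 1.0      # diagonal entries
F = randn(n, r)         # low-rank factor
b = randn(n)
z = woodbury_solve(d, F, b)

# Agrees with the dense solve up to floating-point error
maximum(abs.((Diagonal(d) + F * F') * z .- b))
```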
### MixtureModel Prediction

`StructuredGaussianMixtures.predict` — Method

```julia
predict(dist::MultivariateMixture, x::AbstractVector, input_indices::Union{Vector{Int},AbstractRange}, output_indices::Union{Vector{Int},AbstractRange})
```

Compute the conditional distribution of the output indices given the input indices for a mixture model. Returns a new mixture model in which each component is the conditional distribution of the corresponding original component, and the weights are updated based on the log density of `x` under the components' marginal distributions.
**Arguments**

- `dist`: the multivariate mixture distribution
- `x`: the observed values for the input indices
- `input_indices`: the indices of the observed variables
- `output_indices`: the indices of the variables to predict

**Returns**

- A new mixture model representing the conditional distribution
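The two steps (condition each component, reweight by marginal likelihood) can be sketched with plain Distributions.jl types. This is an illustrative standalone version, not the package's code:

```julia
using Distributions, LinearAlgebra

# Illustrative mixture conditioning: condition each component via the Schur
# complement and reweight by the marginal likelihood of the observation.
function conditional_mixture(components::Vector{<:MvNormal}, weights::Vector{Float64},
                             x, in_idx, out_idx)
    logw = similar(weights)
    conds = MvNormal[]
    for (k, c) in enumerate(components)
        μ, Σ = mean(c), cov(c)
        marg = MvNormal(μ[in_idx], Matrix(Symmetric(Σ[in_idx, in_idx])))
        logw[k] = log(weights[k]) + logpdf(marg, x)   # unnormalized log weight
        K  = Σ[out_idx, in_idx] / Σ[in_idx, in_idx]
        μc = μ[out_idx] + K * (x - μ[in_idx])
        Σc = Σ[out_idx, out_idx] - K * Σ[out_idx, in_idx]'
        push!(conds, MvNormal(μc, Matrix(Symmetric(Σc))))
    end
    w = exp.(logw .- maximum(logw))   # normalize in log space for stability
    w ./= sum(w)
    return MixtureModel(conds, w)
end

# Two well-separated components; the observation x₁ = 2.0 matches the second
comps = [MvNormal([-2.0, 0.0], Matrix(1.0I, 2, 2)),
         MvNormal([2.0, 0.0], Matrix(1.0I, 2, 2))]
post = conditional_mixture(comps, [0.5, 0.5], [2.0], [1], [2])
probs(post)  # second component's weight dominates
```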
### Convenience Function

`StructuredGaussianMixtures.predict` — Method

```julia
predict(dist::Union{MvNormal,LRDMvNormal,MultivariateMixture}, x::AbstractVector;
        input_indices::Union{Vector{Int},AbstractRange} = 1:length(x),
        output_indices::Union{Vector{Int},AbstractRange} = length(x)+1:length(mean(dist)))
```

Compute the conditional distribution of the output indices given the input indices using the Schur complement. Returns a new distribution representing the conditional distribution.
**Arguments**

- `dist`: the distribution to condition (`MvNormal`, `LRDMvNormal`, or a multivariate mixture)
- `x`: the observed values for the input indices
- `input_indices`: the indices of the observed variables (default: the first `length(x)` indices)
- `output_indices`: the indices of the variables to predict (default: the remaining indices)

**Returns**

- A new distribution representing the conditional distribution
## Marginal Functions

### Marginal Distribution

`StructuredGaussianMixtures.marginal` — Method

```julia
marginal(dist::MvNormal, indices::Union{Vector{Int},AbstractRange})
```

Compute the marginal distribution over the specified indices. Returns a new `MvNormal` distribution representing the marginal.
**Arguments**

- `dist`: the multivariate normal distribution
- `indices`: the indices to retain; all other dimensions are marginalized out

**Returns**

- A new `MvNormal` distribution representing the marginal
`StructuredGaussianMixtures.marginal` — Method

```julia
marginal(dist::LRDMvNormal, indices::Union{Vector{Int},AbstractRange})
```

Compute the marginal distribution over the specified indices. Returns a new `LRDMvNormal` distribution representing the marginal.
**Arguments**

- `dist`: the low-rank plus diagonal multivariate normal distribution
- `indices`: the indices to retain; all other dimensions are marginalized out

**Returns**

- A new `LRDMvNormal` distribution representing the marginal
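For a Gaussian, taking a marginal amounts to selecting the corresponding sub-vector of the mean and sub-block of the covariance. A quick illustration with plain Distributions.jl types (not the package's implementation):

```julia
using Distributions, LinearAlgebra

# Marginalizing a Gaussian is index selection on μ and Σ
d = MvNormal([1.0, 2.0, 3.0], [2.0 0.5 0.1; 0.5 1.0 0.3; 0.1 0.3 1.5])
idx = [1, 3]                        # keep dimensions 1 and 3
m = MvNormal(mean(d)[idx], Matrix(Symmetric(cov(d)[idx, idx])))
mean(m)  # [1.0, 3.0]
cov(m)   # [2.0 0.1; 0.1 1.5]
```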
## Simple Examples

### 2D GMM Prediction

```julia
using StructuredGaussianMixtures
using Statistics

# Fit a GMM
data = randn(2, 1000)
gmm = fit(EM(3), data)

# Make prediction
x_query = [0.5]
posterior = predict(gmm, x_query)

# Generate samples from the posterior
samples = rand(posterior, 100)
println("Posterior mean: ", mean(samples))
```
### High-Dimensional Prediction

```julia
# Fit a high-dimensional GMM
data = randn(10, 1000)
gmm = fit(PCAEM(3, 3), data)

# Predict multiple dimensions
observed_dims = [1, 3, 5]
query_values = [0.5, -0.2, 1.1]
unobserved_dims = [2, 4, 6, 7, 8, 9, 10]
posterior = predict(gmm, query_values;
                    input_indices=observed_dims,
                    output_indices=unobserved_dims)

# Generate samples
samples = rand(posterior, 100)
println("Predicted dimensions shape: ", size(samples))
```
### Partial Prediction

```julia
# Fit a 5D GMM
data = randn(5, 1000)
gmm = fit(EM(3), data)

# Observe dimensions 1 and 3; predict dimensions 2, 4, and 5
observed_values = [0.5, -0.2]
observed_dims = [1, 3]
output_dims = [2, 4, 5]
posterior = predict(gmm, observed_values;
                    input_indices=observed_dims,
                    output_indices=output_dims)
samples = rand(posterior, 100)
println("Predicted dimensions shape: ", size(samples))
```
## Mathematical Background

### Conditional Gaussian Distribution

For a Gaussian Mixture Model with components $k = 1, \ldots, K$, the posterior distribution over the unobserved dimensions given observed values $x_{obs}$ is:

\[p(x_{unobs} \mid x_{obs}) = \sum_{k=1}^K w^k \, \mathcal{N}(x_{unobs} \mid \mu^k_{unobs \mid obs}, \Sigma^k_{unobs \mid obs})\]

where $w^k$ are the posterior component weights, and $\mu^k_{unobs \mid obs}$ and $\Sigma^k_{unobs \mid obs}$ are the conditional mean and covariance of component $k$.

### Component Weight Updates

The posterior component weights are obtained by reweighting the prior weights $\pi^k$ by the marginal likelihood of the observation:

\[w^k = \frac{\pi^k \, \mathcal{N}(x_{obs} \mid \mu^k_{obs}, \Sigma^k_{obs,obs})}{\sum_{j=1}^K \pi^j \, \mathcal{N}(x_{obs} \mid \mu^j_{obs}, \Sigma^j_{obs,obs})}\]

### Conditional Parameters

For each component $k$, partition its mean into $\mu^k_{obs}$, $\mu^k_{unobs}$ and its covariance into the blocks $\Sigma^k_{obs,obs}$, $\Sigma^k_{obs,unobs}$, $\Sigma^k_{unobs,obs}$, $\Sigma^k_{unobs,unobs}$. The conditional parameters then follow from the Schur complement:

\[\mu^k_{unobs \mid obs} = \mu^k_{unobs} + \Sigma^k_{unobs,obs} (\Sigma^k_{obs,obs})^{-1} (x_{obs} - \mu^k_{obs})\]

\[\Sigma^k_{unobs \mid obs} = \Sigma^k_{unobs,unobs} - \Sigma^k_{unobs,obs} (\Sigma^k_{obs,obs})^{-1} \Sigma^k_{obs,unobs}\]
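The conditional formulas can be sanity-checked numerically: condition a 2D Gaussian analytically, then compare against a Monte Carlo estimate built from samples whose observed coordinate falls near $x_{obs}$. This is an illustrative check, not part of the package:

```julia
using Distributions, LinearAlgebra, Random, Statistics
Random.seed!(0)

μ = [1.0, -1.0]
Σ = [1.0 0.6; 0.6 2.0]
xobs = 0.5

# Analytic conditional parameters (single component, scalar blocks)
μ_cond   = μ[2] + Σ[2, 1] / Σ[1, 1] * (xobs - μ[1])   # = -1.3
var_cond = Σ[2, 2] - Σ[2, 1]^2 / Σ[1, 1]              # = 1.64

# Monte Carlo: keep samples whose first coordinate lands near xobs
X = rand(MvNormal(μ, Σ), 2_000_000)
sel = abs.(X[1, :] .- xobs) .< 0.01
(μ_cond, mean(X[2, sel]))  # the two estimates agree to roughly 1e-2
```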
## Advanced Usage

### Multiple Query Points

```julia
# Predict for several query points
query_points = [[0.5], [-0.2], [1.1]]
posteriors = [predict(gmm, q) for q in query_points]

# Generate samples from each posterior
samples = [rand(p, 50) for p in posteriors]
```
### Uncertainty Quantification

```julia
using Statistics

# Compute posterior statistics
posterior = predict(gmm, query_point)
samples = rand(posterior, 1000)

# Mean and variance per predicted dimension
posterior_mean = mean(samples, dims=2)
posterior_var = var(samples, dims=2)

# 95% credible intervals (`quantile` has no `dims` keyword, so apply it per row)
posterior_quantiles = mapslices(s -> quantile(s, [0.025, 0.975]), samples, dims=2)
```