Skip to content

Commit b350468

Browse files
committed
Improved Json serialization for MCMCResults and UncertaintyAnalysisResults. Improved initialization and regularization of Gaussian Mixture Model.
1 parent fec21ca commit b350468

3 files changed

Lines changed: 74 additions & 15 deletions

File tree

Numerics/Distributions/Univariate/Uncertainty Analysis/UncertaintyAnalysisResults.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ public static byte[] ToByteArray(UncertaintyAnalysisResults results)
163163
{
164164
WriteIndented = false,
165165
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
166-
IncludeFields = true
166+
IncludeFields = true,
167+
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals
167168
};
168169
// Add custom converters for unsupported types
169170
options.Converters.Add(new Double2DArrayConverter());
@@ -183,7 +184,8 @@ public static byte[] ToByteArray(UncertaintyAnalysisResults results)
183184
var options = new JsonSerializerOptions
184185
{
185186
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
186-
IncludeFields = true
187+
IncludeFields = true,
188+
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals
187189
};
188190
// Add custom converters for unsupported types
189191
options.Converters.Add(new Double2DArrayConverter());

Numerics/Machine Learning/Unsupervised/GaussianMixtureModel.cs

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -186,22 +186,64 @@ public GaussianMixtureModel(Matrix X, int k)
186186
/// <param name="kMeansPlusPlus">Determines whether to use random initialization or to use the k-Means++ method. Default is to use k-Means++.</param>
187187
public void Train(int seed = -1, bool kMeansPlusPlus = true)
188188
{
189-
// 1. Initialize clusters from k-Means
189+
// 1. Initialize clusters from k-Means
190190
var kMeans = new KMeans(X, K);
191191
kMeans.Train(seed, kMeansPlusPlus);
192192
Means = kMeans.Means;
193193

194194
// Give equal weight to each cluster
195-
// And initialize the covariance matrix
195+
// and initialize the covariance matrix from cluster data variance.
196+
// Using actual data variance (instead of a tiny constant like 1e-10)
197+
// prevents the first E-step from computing likelihoods from near-degenerate
198+
// Gaussians, which can cause numerical overflow or underflow.
196199
Weights = new double[K];
197200
LikelihoodMatrix = new double[X.NumberOfRows, K];
198201
Sigmas = new Matrix[K];
199202
for (int k = 0; k < K; k++)
200-
{
203+
{
201204
Weights[k] = 1d / K;
202205
Sigmas[k] = new Matrix(Dimension);
203-
for (int i = 0; i < Dimension; i++)
204-
Sigmas[k][i, i] = 1E-10;
206+
207+
// Compute within-cluster covariance from k-means labels
208+
int clusterCount = 0;
209+
for (int i = 0; i < X.NumberOfRows; i++)
210+
{
211+
if (kMeans.Labels[i] == k)
212+
clusterCount++;
213+
}
214+
215+
if (clusterCount > 1)
216+
{
217+
// Compute covariance from cluster members
218+
for (int d = 0; d < Dimension; d++)
219+
{
220+
for (int j = 0; j < Dimension; j++)
221+
{
222+
double sum = 0;
223+
for (int i = 0; i < X.NumberOfRows; i++)
224+
{
225+
if (kMeans.Labels[i] == k)
226+
sum += (X[i, d] - Means[k, d]) * (X[i, j] - Means[k, j]);
227+
}
228+
Sigmas[k][d, j] = sum / clusterCount;
229+
}
230+
}
231+
}
232+
233+
// Ensure positive-definite: floor diagonal at fraction of overall data variance
234+
for (int d = 0; d < Dimension; d++)
235+
{
236+
double colVar = 0;
237+
double colMean = 0;
238+
for (int i = 0; i < X.NumberOfRows; i++)
239+
colMean += X[i, d];
240+
colMean /= X.NumberOfRows;
241+
for (int i = 0; i < X.NumberOfRows; i++)
242+
colVar += (X[i, d] - colMean) * (X[i, d] - colMean);
243+
colVar /= X.NumberOfRows;
244+
245+
Sigmas[k][d, d] = Math.Max(Sigmas[k][d, d], 1E-6 * colVar);
246+
}
205247
}
206248

207249
// 2. Optimize clusters
@@ -318,13 +360,26 @@ private void MStep()
318360
}
319361
}
320362

321-
// Add small regularization to the diagonal to ensure the covariance
322-
// matrix remains positive-definite. When a component collapses to very
323-
// few points, the covariance can become singular, causing Cholesky
324-
// decomposition in the E-step to fail.
363+
// Floor diagonal at a fraction of the overall data variance to prevent
364+
// component collapse. When a component captures very few points, its
365+
// covariance can become singular, causing Cholesky decomposition in the
366+
// E-step to fail. This mirrors sklearn's reg_covar parameter.
325367
for (int d = 0; d < Dimension; d++)
326-
MatrixRegularization.MakeSymmetricPositiveDefinite(Sigmas[k]);
327-
368+
{
369+
double colVar = 0;
370+
double colMean = 0;
371+
for (int i = 0; i < X.NumberOfRows; i++)
372+
colMean += X[i, d];
373+
colMean /= X.NumberOfRows;
374+
for (int i = 0; i < X.NumberOfRows; i++)
375+
colVar += (X[i, d] - colMean) * (X[i, d] - colMean);
376+
colVar /= X.NumberOfRows;
377+
378+
Sigmas[k][d, d] = Math.Max(Sigmas[k][d, d], 1E-6 * colVar);
379+
}
380+
381+
// Ensure the full covariance matrix remains symmetric positive-definite
382+
MatrixRegularization.MakeSymmetricPositiveDefinite(Sigmas[k]);
328383
}
329384
}
330385

Numerics/Sampling/MCMC/Support/MCMCResults.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ public static byte[] ToByteArray(MCMCResults mcmcResults)
192192
{
193193
WriteIndented = false,
194194
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
195-
IncludeFields = true
195+
IncludeFields = true,
196+
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals
196197
};
197198
options.Converters.Add(new Double2DArrayConverter());
198199
options.Converters.Add(new HistogramConverter());
@@ -208,7 +209,8 @@ public static byte[] ToByteArray(MCMCResults mcmcResults)
208209
var options = new JsonSerializerOptions
209210
{
210211
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
211-
IncludeFields = true
212+
IncludeFields = true,
213+
NumberHandling = JsonNumberHandling.AllowNamedFloatingPointLiterals
212214
};
213215
options.Converters.Add(new Double2DArrayConverter());
214216
options.Converters.Add(new HistogramConverter());

0 commit comments

Comments
 (0)