Skip to content

Commit f834a27

Browse files
author
Ian
committed
Added optimized random SVD algorithm
1 parent b5ad644 commit f834a27

4 files changed

Lines changed: 161 additions & 84 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "A Rust port of LAS2 from SVDLIBC"
44
keywords = ["svd"]
55
categories = ["algorithms", "data-structures", "mathematics", "science"]
66
name = "single-svdlib"
7-
version = "0.4.0"
7+
version = "0.5.0"
88
edition = "2021"
99
license-file = "SVDLIBC-LICENSE.txt"
1010

src/lib.rs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ mod simple_comparison_tests {
226226

227227
// Convert to CSR for processing
228228
let csr = CsrMatrix::from(&test_matrix);
229-
229+
230230
// Run randomized SVD with reasonable defaults for a sparse matrix
231231
let threadpool = ThreadPoolBuilder::new().num_threads(10).build().unwrap();
232232
let result = threadpool.install(|| {
@@ -241,6 +241,38 @@ mod simple_comparison_tests {
241241
});
242242

243243

244+
// Simply verify that the computation succeeds on a highly sparse matrix
245+
assert!(
246+
result.is_ok(),
247+
"Randomized SVD failed on 99% sparse matrix: {:?}",
248+
result.err().unwrap()
249+
);
250+
}
251+
252+
#[test]
253+
fn test_randomized_svd_small_sparse_matrix() {
254+
use crate::{randomized_svd, PowerIterationNormalizer};
255+
256+
// Create a small matrix with high sparsity (99%)
257+
let test_matrix = create_sparse_matrix(1000, 250, 0.01); // 1% non-zeros
258+
259+
// Convert to CSR for processing
260+
let csr = CsrMatrix::from(&test_matrix);
261+
262+
// Run randomized SVD with reasonable defaults for a sparse matrix
263+
let threadpool = ThreadPoolBuilder::new().num_threads(10).build().unwrap();
264+
let result = threadpool.install(|| {
265+
randomized_svd(
266+
&csr,
267+
50, // target rank
268+
10, // oversampling parameter
269+
2, // power iterations
270+
PowerIterationNormalizer::QR, // use QR normalization
271+
Some(42), // random seed
272+
)
273+
});
274+
275+
244276
// Simply verify that the computation succeeds on a highly sparse matrix
245277
assert!(
246278
result.is_ok(),

src/randomized/mod.rs

Lines changed: 126 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
use rayon::iter::ParallelIterator;
21
use crate::error::SvdLibError;
32
use crate::{Diagnostics, SMat, SvdFloat, SvdRec};
43
use nalgebra_sparse::na::{ComplexField, DMatrix, DVector, RealField};
5-
use ndarray::{Array1, Array2};
4+
use ndarray::Array1;
65
use nshare::IntoNdarray2;
76
use rand::prelude::{Distribution, StdRng};
87
use rand::SeedableRng;
98
use rand_distr::Normal;
10-
use std::ops::Mul;
11-
use rayon::current_num_threads;
9+
use rayon::iter::ParallelIterator;
1210
use rayon::prelude::{IndexedParallelIterator, IntoParallelIterator};
11+
use std::ops::Mul;
12+
use crate::utils::determine_chunk_size;
1313

1414
pub enum PowerIterationNormalizer {
1515
QR,
1616
LU,
1717
None,
1818
}
1919

20-
2120
const PARALLEL_THRESHOLD_ROWS: usize = 5000;
2221
const PARALLEL_THRESHOLD_COLS: usize = 1000;
2322
const PARALLEL_THRESHOLD_ELEMENTS: usize = 100_000;
@@ -35,22 +34,30 @@ where
3534
M: SMat<T>,
3635
T: ComplexField,
3736
{
37+
let start = std::time::Instant::now(); // only for debugging
3838
let m_rows = m.nrows();
3939
let m_cols = m.ncols();
4040

4141
let rank = target_rank.min(m_rows.min(m_cols));
4242
let l = rank + n_oversamples;
43+
println!("Basic statistics: {:?}", start.elapsed());
4344

4445
let omega = generate_random_matrix(m_cols, l, seed);
46+
println!("Generated Random Matrix here: {:?}", start.elapsed());
4547

4648
let mut y = DMatrix::<T>::zeros(m_rows, l);
4749
multiply_matrix(m, &omega, &mut y, false);
50+
println!(
51+
"First multiplication took: {:?}, Continuing for power iterations:",
52+
start.elapsed()
53+
);
4854

4955
if n_power_iters > 0 {
5056
let mut z = DMatrix::<T>::zeros(m_cols, l);
5157

52-
for _ in 0..n_power_iters {
58+
for w in 0..n_power_iters {
5359
multiply_matrix(m, &y, &mut z, true);
60+
println!("{}-nd power-iteration forward: {:?}", w, start.elapsed());
5461
match power_iteration_normalizer {
5562
PowerIterationNormalizer::QR => {
5663
let qr = z.qr();
@@ -61,8 +68,14 @@ where
6168
}
6269
PowerIterationNormalizer::None => {}
6370
}
71+
println!(
72+
"{}-nd power-iteration forward, normalization: {:?}",
73+
w,
74+
start.elapsed()
75+
);
6476

6577
multiply_matrix(m, &z, &mut y, false);
78+
println!("{}-nd power-iteration backward: {:?}", w, start.elapsed());
6679
match power_iteration_normalizer {
6780
PowerIterationNormalizer::QR => {
6881
let qr = y.qr();
@@ -71,16 +84,30 @@ where
7184
PowerIterationNormalizer::LU => normalize_columns(&mut y),
7285
PowerIterationNormalizer::None => {}
7386
}
87+
println!(
88+
"{}-nd power-iteration backward, normalization: {:?}",
89+
w,
90+
start.elapsed()
91+
);
7492
}
7593
}
76-
94+
println!(
95+
"Finished power-iteration, continuing QR: {:?}",
96+
start.elapsed()
97+
);
7798
let qr = y.qr();
99+
println!("QR finished: {:?}", start.elapsed());
78100
let q = qr.q();
79101

80102
let mut b = DMatrix::<T>::zeros(q.ncols(), m_cols);
81103
multiply_transposed_by_matrix(&q, m, &mut b);
104+
println!(
105+
"QMB matrix multiplication transposed: {:?}",
106+
start.elapsed()
107+
);
82108

83109
let svd = b.svd(true, true);
110+
println!("SVD decomposition took: {:?}", start.elapsed());
84111
let u_b = svd
85112
.u
86113
.ok_or_else(|| SvdLibError::Las2Error("SVD U computation failed".to_string()))?;
@@ -98,10 +125,15 @@ where
98125

99126
// Convert to the format required by SvdRec
100127
let d = actual_rank;
128+
println!("SVD Result Cropping: {:?}", start.elapsed());
101129

102130
let ut = u.transpose().into_ndarray2();
103-
let s = convert_singular_values(<DVector<T>>::from(singular_values.rows(0, actual_rank)), actual_rank);
131+
let s = convert_singular_values(
132+
<DVector<T>>::from(singular_values.rows(0, actual_rank)),
133+
actual_rank,
134+
);
104135
let vt = vt_subset.into_ndarray2();
136+
println!("Translation to ndarray: {:?}", start.elapsed());
105137

106138
Ok(SvdRec {
107139
d,
@@ -203,60 +235,32 @@ fn normalize_columns<T: SvdFloat + RealField + Send + Sync>(matrix: &mut DMatrix
203235
.collect();
204236

205237
// Apply normalization
206-
scales
207-
.iter()
208-
.for_each(|(j, scale)| {
209-
for i in 0..rows {
210-
let value = matrix.get_mut((i,*j)).unwrap();
211-
*value = value.clone() * scale.clone();
212-
}
213-
});
238+
scales.iter().for_each(|(j, scale)| {
239+
for i in 0..rows {
240+
let value = matrix.get_mut((i, *j)).unwrap();
241+
*value = value.clone() * scale.clone();
242+
}
243+
});
214244
}
215245

216246
// ----------------------------------------
217247
// Utils Functions
218248
// ----------------------------------------
219249

220-
221250
fn generate_random_matrix<T: SvdFloat + RealField>(
222251
rows: usize,
223252
cols: usize,
224253
seed: Option<u64>,
225254
) -> DMatrix<T> {
226-
//if rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS && rows * cols < PARALLEL_THRESHOLD_ELEMENTS {
227-
let mut rng = match seed {
228-
Some(s) => StdRng::seed_from_u64(s),
229-
None => StdRng::seed_from_u64(0),
230-
};
231-
232-
let normal = Normal::new(0.0, 1.0).unwrap();
233-
return DMatrix::from_fn(rows, cols, |_, _| {
234-
T::from_f64(normal.sample(&mut rng)).unwrap()
235-
});
236-
//}
237-
238-
/*let seed_value = seed.unwrap_or(0);
239-
let mut matrix = DMatrix::<T>::zeros(rows, cols);
240-
let num_threads = current_num_threads();
241-
let chunk_size = (rows * cols + num_threads - 1) / num_threads;
242-
243-
(0..(rows * cols)).into_par_iter()
244-
.chunks(chunk_size)
245-
.enumerate()
246-
.for_each(|(chunk_idx, indices)| {
247-
let thread_seed = seed_value.wrapping_add(chunk_idx as u64);
248-
let mut rng = StdRng::seed_from_u64(thread_seed);
249-
let normal = Normal::new(0.0, 1.0).unwrap();
250-
for idx in indices {
251-
let i = idx / cols;
252-
let j = idx % cols;
253-
unsafe {
254-
*matrix.get_unchecked_mut((i, j)) = T::from_f64(normal.sample(&mut rng)).unwrap();
255-
}
256-
}
257-
});
258-
matrix*/
259-
255+
let mut rng = match seed {
256+
Some(s) => StdRng::seed_from_u64(s),
257+
None => StdRng::seed_from_u64(0),
258+
};
259+
260+
let normal = Normal::new(0.0, 1.0).unwrap();
261+
DMatrix::from_fn(rows, cols, |_, _| {
262+
T::from_f64(normal.sample(&mut rng)).unwrap()
263+
})
260264
}
261265

262266
fn multiply_matrix<T: SvdFloat, M: SMat<T>>(
@@ -266,53 +270,94 @@ fn multiply_matrix<T: SvdFloat, M: SMat<T>>(
266270
transpose_sparse: bool,
267271
) {
268272
let cols = dense.ncols();
269-
//let matrix_rows = if transpose_sparse { sparse.ncols() } else { sparse.nrows() };
270273

271-
//if matrix_rows < PARALLEL_THRESHOLD_ROWS && cols < PARALLEL_THRESHOLD_COLS {
272-
let mut col_vec = vec![T::zero(); dense.nrows()];
273-
let mut result_vec = vec![T::zero(); result.nrows()];
274+
let results: Vec<(usize, Vec<T>)> = (0..cols)
275+
.into_par_iter()
276+
.map(|j| {
277+
let mut col_vec = vec![T::zero(); dense.nrows()];
278+
let mut result_vec = vec![T::zero(); result.nrows()];
274279

275-
for j in 0..cols {
276-
// Extract column from dense matrix
277280
for i in 0..dense.nrows() {
278281
col_vec[i] = dense[(i, j)];
279282
}
280283

281-
// Perform sparse matrix operation
282284
sparse.svd_opa(&col_vec, &mut result_vec, transpose_sparse);
283285

284-
// Store results
285-
for i in 0..result.nrows() {
286-
result[(i, j)] = result_vec[i];
287-
}
286+
(j, result_vec)
287+
})
288+
.collect();
288289

289-
// Clear result vector for reuse
290-
result_vec.iter_mut().for_each(|v| *v = T::zero());
290+
for (j, col_result) in results {
291+
for i in 0..result.nrows() {
292+
result[(i, j)] = col_result[i];
291293
}
292-
return;
293-
//}
294-
295-
294+
}
296295
}
297296

298297
fn multiply_transposed_by_matrix<T: SvdFloat, M: SMat<T>>(
299-
q: &DMatrix<T>,
298+
q: &DMatrix<T>,
300299
sparse: &M,
301300
result: &mut DMatrix<T>,
302301
) {
303-
for j in 0..sparse.ncols() {
304-
let mut unit_vec = vec![T::zero(); sparse.ncols()];
305-
unit_vec[j] = T::one();
306-
307-
let mut col_vec = vec![T::zero(); sparse.nrows()];
308-
sparse.svd_opa(&unit_vec, &mut col_vec, false);
309-
310-
for i in 0..q.ncols() {
311-
let mut sum = T::zero();
312-
for k in 0..q.nrows() {
313-
sum += q[(k, i)] * col_vec[k];
302+
let q_rows = q.nrows();
303+
let q_cols = q.ncols();
304+
let sparse_rows = sparse.nrows();
305+
let sparse_cols = sparse.ncols();
306+
307+
eprintln!("Q dimensions: {} x {}", q_rows, q_cols);
308+
eprintln!("Sparse dimensions: {} x {}", sparse_rows, sparse_cols);
309+
eprintln!("Result dimensions: {} x {}", result.nrows(), result.ncols());
310+
311+
assert_eq!(
312+
q_rows, sparse_rows,
313+
"Dimension mismatch: Q has {} rows but sparse has {} rows",
314+
q_rows, sparse_rows
315+
);
316+
317+
assert_eq!(
318+
result.nrows(),
319+
q_cols,
320+
"Result matrix has incorrect row count: expected {}, got {}",
321+
q_cols,
322+
result.nrows()
323+
);
324+
assert_eq!(
325+
result.ncols(),
326+
sparse_cols,
327+
"Result matrix has incorrect column count: expected {}, got {}",
328+
sparse_cols,
329+
result.ncols()
330+
);
331+
332+
let chunk_size = determine_chunk_size(q_cols);
333+
334+
let chunk_results: Vec<Vec<(usize, Vec<T>)>> = (0..q_cols)
335+
.into_par_iter()
336+
.chunks(chunk_size)
337+
.map(|chunk| {
338+
let mut chunk_results = Vec::with_capacity(chunk.len());
339+
340+
for &col_idx in &chunk {
341+
let mut q_col = vec![T::zero(); q_rows];
342+
for i in 0..q_rows {
343+
q_col[i] = q[(i, col_idx)];
344+
}
345+
346+
let mut result_row = vec![T::zero(); sparse_cols];
347+
348+
sparse.svd_opa(&q_col, &mut result_row, true);
349+
350+
chunk_results.push((col_idx, result_row));
351+
}
352+
chunk_results
353+
})
354+
.collect();
355+
356+
for chunk_result in chunk_results {
357+
for (row_idx, row_values) in chunk_result {
358+
for j in 0..sparse_cols {
359+
result[(row_idx, j)] = row_values[j];
314360
}
315-
result[(i, j)] = sum;
316361
}
317362
}
318363
}

0 commit comments

Comments
 (0)