| 1 | +use rayon::iter::IndexedParallelIterator; |
1 | 2 | use crate::utils::determine_chunk_size; |
2 | 3 | use crate::{SMat, SvdFloat}; |
3 | 4 | use nalgebra_sparse::CsrMatrix; |
4 | 5 | use num_traits::Float; |
5 | 6 | use rayon::iter::ParallelIterator; |
6 | | -use rayon::prelude::{IntoParallelIterator, ParallelBridge}; |
| 7 | +use rayon::prelude::{IntoParallelIterator, ParallelBridge, ParallelSliceMut}; |
7 | 8 | use std::ops::AddAssign; |
8 | 9 |
9 | 10 | pub struct MaskedCSRMatrix<'a, T: Float> { |
@@ -86,7 +87,6 @@ impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T> |
86 | 87 | } |
87 | 88 |
88 | 89 | fn svd_opa(&self, x: &[T], y: &mut [T], transposed: bool) { |
89 | | - // TODO parallelize me please |
90 | 90 | let nrows = if transposed { |
91 | 91 | self.ncols() |
92 | 92 | } else { |
@@ -117,93 +117,87 @@ impl<'a, T: Float + AddAssign + Sync + Send> SMat<T> for MaskedCSRMatrix<'a, T> |
117 | 117 |
118 | 118 | y.fill(T::zero()); |
119 | 119 |
120 | | - let high_precision_mode = self.ensure_identical_results_mode(); |
121 | | - |
122 | 120 | if !transposed { |
123 | | - if high_precision_mode && self.uses_all_columns() { |
124 | | - // For small matrices using all columns, mimic the exact behavior of |
125 | | - // the original implementation to ensure identical results |
126 | | - for i in 0..self.matrix.nrows() { |
127 | | - let mut sum = T::zero(); |
128 | | - for j in major_offsets[i]..major_offsets[i + 1] { |
129 | | - let col = minor_indices[j]; |
130 | | - // For all-columns mode, we know all columns are included |
131 | | - let masked_col = self.original_to_masked[col].unwrap(); |
132 | | - sum = sum + (values[j] * x[masked_col]); |
133 | | - } |
134 | | - y[i] = sum; |
135 | | - } |
136 | | - } else { |
137 | | - let chunk_size = determine_chunk_size(self.matrix.nrows()); |
138 | | - y.chunks_mut(chunk_size).enumerate().par_bridge().for_each( |
139 | | - |(chunk_idx, y_chunk)| { |
140 | | - let start_row = chunk_idx * chunk_size; |
141 | | - let end_row = (start_row + y_chunk.len()).min(self.matrix.nrows()); |
142 | | - |
143 | | - for i in start_row..end_row { |
144 | | - let row_idx = i - start_row; |
145 | | - let mut sum = T::zero(); |
146 | | - |
147 | | - for j in major_offsets[i]..major_offsets[i + 1] { |
148 | | - let col = minor_indices[j]; |
149 | | - if let Some(masked_col) = self.original_to_masked[col] { |
150 | | - sum += values[j] * x[masked_col]; |
151 | | - }; |
152 | | - } |
153 | | - y_chunk[row_idx] = sum; |
154 | | - } |
155 | | - }, |
156 | | - ); |
| 121 | + // A * x calculation |
| 122 | + let row_count = self.matrix.nrows(); |
| 123 | + let (major_offsets, minor_indices, values) = self.matrix.csr_data(); |
| 124 | + |
| 125 | + let chunk_size = std::cmp::max(16, row_count / (rayon::current_num_threads() * 2)); |
| 126 | + |
| 127 | + let mut valid_indices = Vec::with_capacity(self.matrix.ncols()); |
| 128 | + for col in 0..self.matrix.ncols() { |
| 129 | + valid_indices.push(self.original_to_masked[col]); |
157 | 130 | } |
158 | | - } else { |
159 | | - // For the transposed case (A^T * x) |
160 | | - if high_precision_mode && self.uses_all_columns() { |
161 | | - // Clear the output vector first |
162 | | - for yval in y.iter_mut() { |
163 | | - *yval = T::zero(); |
164 | | - } |
165 | | - |
166 | | - // Follow exact same order of operations as original implementation |
167 | | - for i in 0..self.matrix.nrows() { |
168 | | - let row_val = x[i]; |
169 | | - for j in major_offsets[i]..major_offsets[i + 1] { |
170 | | - let col = minor_indices[j]; |
171 | | - let masked_col = self.original_to_masked[col].unwrap(); |
172 | | - y[masked_col] = y[masked_col] + (values[j] * row_val); |
173 | | - } |
174 | | - } |
175 | | - } else { |
176 | | - let nrows = self.matrix.nrows(); |
177 | | - let chunk_size = determine_chunk_size(nrows); |
178 | | - let num_chunks = (nrows + chunk_size - 1) / chunk_size; |
179 | | - let results: Vec<Vec<T>> = (0..chunk_size) |
180 | | - .into_par_iter() |
181 | | - .map(|chunk_idx| { |
182 | | - let start = chunk_idx * chunk_size; |
183 | | - let end = (start + chunk_size).min(nrows); |
184 | | - |
185 | | - let mut local_y = vec![T::zero(); y.len()]; |
186 | | - for i in start..end { |
187 | | - let row_val = x[i]; |
188 | | - for j in major_offsets[i]..major_offsets[i + 1] { |
189 | | - let col = minor_indices[j]; |
190 | | - if let Some(masked_col) = self.original_to_masked[col] { |
191 | | - local_y[masked_col] += values[j] * row_val; |
| 131 | + |
| 132 | + y.par_chunks_mut(chunk_size) |
| 133 | + .enumerate() |
| 134 | + .for_each(|(chunk_idx, y_chunk)| { |
| 135 | + let start_row = chunk_idx * chunk_size; |
| 136 | + let end_row = (start_row + y_chunk.len()).min(row_count); |
| 137 | + |
| 138 | + for i in start_row..end_row { |
| 139 | + let row_idx = i - start_row; |
| 140 | + let mut sum = T::zero(); |
| 141 | + |
| 142 | + let row_start = major_offsets[i]; |
| 143 | + let row_end = major_offsets[i + 1]; |
| 144 | + |
| 145 | + let mut j = row_start; |
| 146 | + |
| 147 | + while j + 4 <= row_end { |
| 148 | + for offset in 0..4 { |
| 149 | + let idx = j + offset; |
| 150 | + let col = minor_indices[idx]; |
| 151 | + if let Some(masked_col) = valid_indices[col] { |
| 152 | + sum += values[idx] * x[masked_col]; |
192 | 153 | } |
193 | 154 | } |
| 155 | + j += 4; |
| 156 | + } |
| 157 | + |
| 158 | + while j < row_end { |
| 159 | + let col = minor_indices[j]; |
| 160 | + if let Some(masked_col) = valid_indices[col] { |
| 161 | + sum += values[j] * x[masked_col]; |
| 162 | + } |
| 163 | + j += 1; |
194 | 164 | } |
195 | | - local_y |
196 | | - }) |
197 | | - .collect(); |
198 | | - |
199 | | - y.fill(T::zero()); |
200 | 165 |
201 | | - for local_y in results { |
202 | | - for (idx, val) in local_y.iter().enumerate() { |
203 | | - if !val.is_zero() { |
204 | | - y[idx] += *val; |
| 166 | + y_chunk[row_idx] = sum; |
| 167 | + } |
| 168 | + }); |
| 169 | + } else { |
| 170 | + // A^T * x calculation |
| 171 | + let nrows = self.matrix.nrows(); |
| 172 | + let chunk_size = determine_chunk_size(nrows); |
| 173 | + |
| 174 | + // Process in parallel chunks |
| 175 | + let results: Vec<Vec<T>> = (0..((nrows + chunk_size - 1) / chunk_size)) |
| 176 | + .into_par_iter() |
| 177 | + .map(|chunk_idx| { |
| 178 | + let start = chunk_idx * chunk_size; |
| 179 | + let end = (start + chunk_size).min(nrows); |
| 180 | + |
| 181 | + let mut local_y = vec![T::zero(); y.len()]; |
| 182 | + for i in start..end { |
| 183 | + let row_val = x[i]; |
| 184 | + for j in major_offsets[i]..major_offsets[i + 1] { |
| 185 | + let col = minor_indices[j]; |
| 186 | + if let Some(masked_col) = self.original_to_masked[col] { |
| 187 | + local_y[masked_col] += values[j] * row_val; |
| 188 | + } |
205 | 189 | } |
206 | 190 | } |
| 191 | + local_y |
| 192 | + }) |
| 193 | + .collect(); |
| 194 | + |
| 195 | + // Combine results |
| 196 | + for local_y in results { |
| 197 | + for (idx, val) in local_y.iter().enumerate() { |
| 198 | + if !val.is_zero() { |
| 199 | + y[idx] += *val; |
| 200 | + } |
207 | 201 | } |
208 | 202 | } |
209 | 203 | } |
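And a matching sketch of the transposed path (`y = Aᵀ * x`) under the same assumptions: each chunk of rows scatters into its own buffer sized like `y`, and the buffers are summed sequentially at the end, so no two threads ever write the same output entry. Again, the names are illustrative, not the crate's API:

```rust
use rayon::prelude::*;

/// y = Aᵀ * x over a column-masked CSR matrix A.
/// `x.len()` must equal the number of rows of A; `y.len()` equals the number
/// of masked (kept) columns.
fn masked_spmv_transposed(
    row_offsets: &[usize],
    col_indices: &[usize],
    values: &[f64],
    col_mask: &[Option<usize>],
    x: &[f64],
    y: &mut [f64],
) {
    let nrows = row_offsets.len() - 1;
    let chunk_size = std::cmp::max(16, nrows / (rayon::current_num_threads() * 2));
    let num_chunks = (nrows + chunk_size - 1) / chunk_size;

    // Each chunk accumulates into a private dense buffer instead of writing to `y` directly.
    let partials: Vec<Vec<f64>> = (0..num_chunks)
        .into_par_iter()
        .map(|chunk_idx| {
            let start = chunk_idx * chunk_size;
            let end = (start + chunk_size).min(nrows);
            let mut local_y = vec![0.0; y.len()];
            for i in start..end {
                let row_val = x[i];
                for j in row_offsets[i]..row_offsets[i + 1] {
                    if let Some(masked_col) = col_mask[col_indices[j]] {
                        local_y[masked_col] += values[j] * row_val;
                    }
                }
            }
            local_y
        })
        .collect();

    // Sequential reduction of the per-chunk buffers.
    y.fill(0.0);
    for local_y in partials {
        for (idx, val) in local_y.into_iter().enumerate() {
            y[idx] += val;
        }
    }
}
```

The per-chunk buffers cost roughly `num_chunks * y.len()` extra memory; trading that for race-free, lock-free writes is usually worthwhile when `y` is much shorter than the number of nonzeros being scattered.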