1414
1515impl < T > GaussianConnectivity < T >
1616where
17- T : FloatOpsTS ,
17+ T : FloatOpsTS + ' static ,
1818{
1919 pub fn new ( knn : bool ) -> Self {
2020 Self {
@@ -40,15 +40,19 @@ where
4040 distances : & CsrMatrix < T > ,
4141 n_neighbors : usize ,
4242 ) -> ( Array2 < usize > , Array2 < T > ) {
43-
4443 let n_obs = distances. nrows ( ) ;
4544 let mut knn_indices = Array2 :: < usize > :: zeros ( ( n_obs, n_neighbors) ) ;
4645 let mut knn_distances = Array2 :: zeros ( ( n_obs, n_neighbors) ) ;
4746
4847 for i in 0 ..n_obs {
4948 let mut neighbors: Vec < ( usize , T ) > = Vec :: new ( ) ;
5049
51- for ( col, & dist) in distances. row ( i) . col_indices ( ) . iter ( ) . zip ( distances. row ( i) . values ( ) . iter ( ) ) {
50+ for ( col, & dist) in distances
51+ . row ( i)
52+ . col_indices ( )
53+ . iter ( )
54+ . zip ( distances. row ( i) . values ( ) . iter ( ) )
55+ {
5256 if * col != i || dist > T :: zero ( ) {
5357 neighbors. push ( ( * col, dist) ) ;
5458 }
7276
7377 for i in 0 ..n_obs {
7478 let sigma_sq = if self . knn {
75- let mut distances_sq: Vec < T > = knn_distances. row ( i)
79+ let mut distances_sq: Vec < T > = knn_distances
80+ . row ( i)
7681 . iter ( )
7782 . filter ( |& & d| d > T :: zero ( ) )
7883 . map ( |& d| d * d)
8186 if distances_sq. is_empty ( ) {
8287 T :: from_f64 ( 1.0 ) . unwrap ( )
8388 } else {
84- distances_sq. sort_by ( |a, b| a. partial_cmp ( b) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
89+ distances_sq
90+ . sort_by ( |a, b| a. partial_cmp ( b) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
8591 let median_idx = distances_sq. len ( ) / 2 ;
8692 distances_sq[ median_idx]
8793 }
@@ -113,15 +119,18 @@ where
113119 for j in 0 ..n_neighbors {
114120 let neighbor_idx = knn_indices[ [ i, j] ] ;
115121
116- let pair = if i < neighbor_idx { ( i, neighbor_idx) } else { ( neighbor_idx, i) } ;
122+ let pair = if i < neighbor_idx {
123+ ( i, neighbor_idx)
124+ } else {
125+ ( neighbor_idx, i)
126+ } ;
117127 if processed_pairs. contains ( & pair) {
118128 continue ;
119129 }
120130 processed_pairs. insert ( pair) ;
121131
122132 if let Some ( dist_sq) = distances. get_entry ( i, neighbor_idx) {
123133 let dist_sq = dist_sq. into_value ( ) ;
124- let dist_sq = dist_sq * dist_sq;
125134 let weight = self . compute_gaussian_weight ( i, neighbor_idx, dist_sq, sigmas) ;
126135
127136 if weight > self . min_weight_threshold {
@@ -136,11 +145,11 @@ where
136145 } else {
137146 // For dense: compute all pairwise weights above threshold
138147 for i in 0 ..n_obs {
139- for j in i..n_obs { // Only upper triangle, then make symmetric
148+ for j in i..n_obs {
149+ // Only upper triangle, then make symmetric
140150 if let Some ( dist) = distances. get_entry ( i, j) {
141151 let dist = dist. into_value ( ) ;
142- let dist_sq = dist * dist;
143- let weight = self . compute_gaussian_weight ( i, j, dist_sq, sigmas) ;
152+ let weight = self . compute_gaussian_weight ( i, j, dist, sigmas) ;
144153
145154 if weight > self . min_weight_threshold {
146155 triplets. push ( ( i, j, weight) ) ;
@@ -154,18 +163,13 @@ where
154163 }
155164
156165 // Convert to CSR matrix
157- let mut rows = Vec :: new ( ) ;
158- let mut cols = Vec :: new ( ) ;
159- let mut data = Vec :: new ( ) ;
160-
161- for ( row, col, val) in triplets {
162- rows. push ( row) ;
163- cols. push ( col) ;
164- data. push ( val) ;
165- }
166+ let rows: Vec < usize > = triplets. iter ( ) . map ( |( r, _, _) | * r) . collect ( ) ;
167+ let cols: Vec < usize > = triplets. iter ( ) . map ( |( _, c, _) | * c) . collect ( ) ;
168+ let data: Vec < T > = triplets. iter ( ) . map ( |( _, _, v) | * v) . collect ( ) ;
166169
167- CsrMatrix :: try_from_csr_data ( n_obs, n_obs, rows, cols, data)
168- . expect ( "Failed to create Gaussian connectivity matrix" )
170+ let coo = nalgebra_sparse:: CooMatrix :: try_from_triplets ( n_obs, n_obs, rows, cols, data)
171+ . expect ( "Failed to create COO matrix" ) ;
172+ CsrMatrix :: from ( & coo)
169173 }
170174
171175 fn compute_gaussian_weight ( & self , i : usize , j : usize , dist_sq : T , sigmas : & [ T ] ) -> T {
0 commit comments