1- pub mod knn;
2- pub mod connectivity;
1+ use kiddo:: traits:: DistanceMetric ;
2+ use nalgebra_sparse:: { CooMatrix , CsrMatrix } ;
3+ use ndarray:: ArrayViewD ;
4+ use single_utilities:: traits:: FloatOpsTS ;
5+
6+ pub struct NeighborResult < T > {
7+ pub distances : CsrMatrix < T > ,
8+ pub connectivities : CsrMatrix < T > ,
9+ }
10+
11+ pub fn knn_arrayd_kiddo_gaussian < T , const K : usize , D > (
12+ data : ArrayViewD < T > ,
13+ k : u64 ,
14+ ) -> anyhow:: Result < NeighborResult < T > >
15+ where
16+ T : FloatOpsTS + ' static ,
17+ D : DistanceMetric < T , K > ,
18+ {
19+ if data. ndim ( ) != 2 {
20+ return Err ( anyhow:: anyhow!(
21+ "The input array has to have two dimensions."
22+ ) ) ; // TODO error message fix
23+ }
24+
25+ let shape = data. shape ( ) ;
26+ let n_samples = shape[ 0 ] as u64 ;
27+ let n_features = shape[ 1 ] as u64 ;
28+
29+ if ( n_features as usize ) < K {
30+ return Err ( anyhow:: anyhow!(
31+ "The data must have at least K features in order to be used for KNN calculation"
32+ ) ) ;
33+ }
34+
35+ let mut kdtree: kiddo:: KdTree < T , K > = kiddo:: KdTree :: new ( ) ;
36+
37+ for i in 0 ..n_samples {
38+ let mut point_array = [ T :: zero ( ) ; K ] ;
39+ for j in 0 ..K {
40+ point_array[ j] = * data. get ( [ i as usize , j] ) . unwrap_or ( & T :: zero ( ) ) ;
41+ }
42+ kdtree. add ( & point_array, i) ;
43+ }
44+
45+ let mut knn_indices = Vec :: with_capacity ( n_samples as usize ) ;
46+ let mut knn_distances_sq = Vec :: with_capacity ( n_samples as usize ) ;
47+
48+ for i in 0 ..n_samples {
49+ let mut query_array = [ T :: zero ( ) ; K ] ;
50+ for j in 0 ..K {
51+ query_array[ j] = * data. get ( [ i as usize , j] ) . unwrap_or ( & T :: zero ( ) ) ;
52+ }
53+
54+ let neighbors = kdtree. nearest_n :: < D > ( & query_array, ( k + 1 ) as usize ) ;
55+ let mut indices = Vec :: with_capacity ( k as usize + 1 ) ;
56+ let mut distances_sq = Vec :: with_capacity ( k as usize + 1 ) ;
57+
58+ for neighbor in neighbors. iter ( ) {
59+ indices. push ( neighbor. item as usize ) ;
60+ distances_sq. push ( neighbor. distance ) ;
61+ }
62+
63+ knn_indices. push ( indices) ;
64+ knn_distances_sq. push ( distances_sq) ;
65+ }
66+
67+ let mut distance_triplets = Vec :: new ( ) ;
68+
69+ for i in 0 ..n_samples as usize {
70+ for ( idx, & j) in knn_indices[ i] . iter ( ) . enumerate ( ) {
71+ distance_triplets. push ( ( i, j, knn_distances_sq[ i] [ idx] ) ) ;
72+ }
73+ }
74+
75+ let mut sigmas_sq = Vec :: with_capacity ( n_samples as usize ) ;
76+
77+ for i in 0 ..n_samples as usize {
78+ let mut dist_wo_self: Vec < T > = knn_distances_sq[ i]
79+ . iter ( )
80+ . filter ( |& & d| d > T :: zero ( ) )
81+ . copied ( )
82+ . collect ( ) ;
83+
84+ let sigma = if dist_wo_self. is_empty ( ) {
85+ T :: one ( )
86+ } else {
87+ dist_wo_self. sort_by ( |a, b| a. partial_cmp ( b) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
88+ let median_idx = dist_wo_self. len ( ) / 2 ;
89+ dist_wo_self[ median_idx]
90+ } ;
91+ sigmas_sq. push ( sigma) ;
92+ }
93+
94+ let mut connectivity_triplets = Vec :: new ( ) ;
95+ let min_weight = T :: from_f64 ( 1e-14 ) . unwrap ( ) ;
96+
97+ for i in 0 ..n_samples as usize {
98+ for & j in knn_indices[ i] . iter ( ) . skip ( 1 ) {
99+ if i <= j {
100+ // place here upper triangle restriction
101+ let dist_sq = if let Some ( pos) = knn_indices[ i] . iter ( ) . position ( |& x| x == j) {
102+ knn_distances_sq[ i] [ pos]
103+ } else {
104+ continue ;
105+ } ;
106+
107+ let sigma_i_sq = sigmas_sq[ i] ;
108+ let sigma_j_sq = sigmas_sq[ j] ;
109+ let sigma_i = sigma_i_sq. sqrt ( ) ;
110+ let sigma_j = sigma_j_sq. sqrt ( ) ;
111+ let num = T :: from ( 2 ) . unwrap ( ) * sigma_i * sigma_j;
112+ let den = sigma_i_sq + sigma_j_sq;
113+
114+ let weight = if den > T :: zero ( ) {
115+ let normalization = ( num / den) . sqrt ( ) ;
116+ let exponential = ( -dist_sq / den) . exp ( ) ;
117+ normalization * exponential
118+ } else {
119+ T :: zero ( )
120+ } ;
121+
122+ if weight > min_weight {
123+ connectivity_triplets. push ( ( i, j, weight) ) ;
124+ if i != j {
125+ connectivity_triplets. push ( ( j, i, weight) ) ; // symmetry with just one computation step
126+ }
127+ }
128+ }
129+ }
130+ }
131+
132+ let distances_coo = CooMatrix :: try_from_triplets (
133+ n_samples as usize ,
134+ n_samples as usize ,
135+ distance_triplets. iter ( ) . map ( |& ( i, _, _) | i) . collect ( ) ,
136+ distance_triplets. iter ( ) . map ( |& ( _, j, _) | j) . collect ( ) ,
137+ distance_triplets. iter ( ) . map ( |& ( _, _, v) | v) . collect ( ) ,
138+ )
139+ . map_err ( |e| anyhow:: anyhow!( "Failed to create distance COO matrix: {}" , e) ) ?;
140+
141+ let connectivities_coo = CooMatrix :: try_from_triplets (
142+ n_samples as usize ,
143+ n_samples as usize ,
144+ connectivity_triplets. iter ( ) . map ( |& ( i, _, _) | i) . collect ( ) ,
145+ connectivity_triplets. iter ( ) . map ( |& ( _, j, _) | j) . collect ( ) ,
146+ connectivity_triplets. iter ( ) . map ( |& ( _, _, v) | v) . collect ( ) ,
147+ )
148+ . map_err ( |e| anyhow:: anyhow!( "Failed to create connectivity COO matrix: {}" , e) ) ?;
149+
150+ Ok ( NeighborResult {
151+ distances : CsrMatrix :: from ( & distances_coo) ,
152+ connectivities : CsrMatrix :: from ( & connectivities_coo) ,
153+ } )
154+ }
0 commit comments