Skip to content

Commit c2dc496

Browse files
author
Ian
committed
small refactoring
1 parent 5706d1e commit c2dc496

4 files changed

Lines changed: 170 additions & 48 deletions

File tree

src/moving/standard.rs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
use rand::prelude::SliceRandom;
1+
use crate::network::Network;
2+
use crate::network::grouping::NetworkGrouping;
23
use rand::RngCore;
4+
use rand::prelude::SliceRandom;
35
use single_utilities::traits::FloatOpsTS;
4-
use crate::network::grouping::NetworkGrouping;
5-
use crate::network::Network;
66

77
#[derive(Debug)]
88
pub struct StandardLocalMoving<T>
99
where
10-
T: FloatOpsTS, {
10+
T: FloatOpsTS,
11+
{
1112
resolution: T,
1213
cluster_weights: Vec<T>,
1314
nodes_per_cluster: Vec<usize>,
@@ -19,8 +20,8 @@ where
1920

2021
impl<T> StandardLocalMoving<T>
2122
where
22-
T: FloatOpsTS + 'static {
23-
23+
T: FloatOpsTS + 'static,
24+
{
2425
pub fn new(resolution: T) -> Self {
2526
StandardLocalMoving {
2627
resolution,
@@ -62,8 +63,7 @@ where
6263

6364
for i in 0..node_count {
6465
let cluster = clustering.get_group(i);
65-
self.cluster_weights[cluster] = self.cluster_weights[cluster]
66-
+ T::from_f64(network.weight(i).to_f64().unwrap()).unwrap();
66+
self.cluster_weights[cluster] += network.weight(i);
6767
self.nodes_per_cluster[cluster] += 1;
6868
}
6969

@@ -95,8 +95,7 @@ where
9595

9696
// Remove node from current cluster
9797
let node_weight = T::from_f64(network.weight(node).to_f64().unwrap()).unwrap();
98-
self.cluster_weights[current_cluster] =
99-
self.cluster_weights[current_cluster] - node_weight;
98+
self.cluster_weights[current_cluster] -= node_weight;
10099
self.nodes_per_cluster[current_cluster] -= 1;
101100

102101
if self.nodes_per_cluster[current_cluster] == 0 {
@@ -118,22 +117,21 @@ where
118117
self.neighboring_clusters[num_neighboring_clusters] = neighbor_cluster;
119118
num_neighboring_clusters += 1;
120119
}
121-
self.edge_weight_per_cluster[neighbor_cluster] =
122-
self.edge_weight_per_cluster[neighbor_cluster] + edge_weight;
120+
self.edge_weight_per_cluster[neighbor_cluster] += edge_weight;
123121
}
124122

125123
// Find best cluster
126124
let mut best_cluster = current_cluster;
127125
let mut max_quality_increment = self.edge_weight_per_cluster[current_cluster]
128126
- (node_weight * self.cluster_weights[current_cluster] * self.resolution)
129-
/ (T::from_f64(2.0).unwrap() * total_edge_weight);
127+
/ (T::from_f64(2.0).unwrap() * total_edge_weight);
130128

131129
//println!("ITERATION | Best Cluster {:?} Max Quality Increment {:?}", best_cluster, max_quality_increment.to_f64().unwrap());
132130

133131
for &cluster in &self.neighboring_clusters[..num_neighboring_clusters] {
134132
let quality_increment = self.edge_weight_per_cluster[cluster]
135133
- (node_weight * self.cluster_weights[cluster] * self.resolution)
136-
/ (T::from_f64(2.0).unwrap() * total_edge_weight);
134+
/ (T::from_f64(2.0).unwrap() * total_edge_weight);
137135
//println!("ITERATION | Cluster {:?} Quality Increment {:?}", cluster, quality_increment.to_f64().unwrap());
138136
if quality_increment > max_quality_increment
139137
|| (quality_increment == max_quality_increment && cluster < best_cluster)
@@ -146,7 +144,7 @@ where
146144
}
147145

148146
// Update cluster assignment
149-
self.cluster_weights[best_cluster] = self.cluster_weights[best_cluster] + node_weight;
147+
self.cluster_weights[best_cluster] += node_weight;
150148
self.nodes_per_cluster[best_cluster] += 1;
151149

152150
if best_cluster == self.unused_clusters[num_unused_clusters - 1] {
@@ -170,5 +168,4 @@ where
170168

171169
update
172170
}
173-
174-
}
171+
}

src/neighborhood/mod.rs

Lines changed: 154 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,154 @@
1-
pub mod knn;
2-
pub mod connectivity;
1+
use kiddo::traits::DistanceMetric;
2+
use nalgebra_sparse::{CooMatrix, CsrMatrix};
3+
use ndarray::ArrayViewD;
4+
use single_utilities::traits::FloatOpsTS;
5+
6+
pub struct NeighborResult<T> {
7+
pub distances: CsrMatrix<T>,
8+
pub connectivities: CsrMatrix<T>,
9+
}
10+
11+
pub fn knn_arrayd_kiddo_gaussian<T, const K: usize, D>(
12+
data: ArrayViewD<T>,
13+
k: u64,
14+
) -> anyhow::Result<NeighborResult<T>>
15+
where
16+
T: FloatOpsTS + 'static,
17+
D: DistanceMetric<T, K>,
18+
{
19+
if data.ndim() != 2 {
20+
return Err(anyhow::anyhow!(
21+
"The input array has to have two dimensions."
22+
)); // TODO error message fix
23+
}
24+
25+
let shape = data.shape();
26+
let n_samples = shape[0] as u64;
27+
let n_features = shape[1] as u64;
28+
29+
if (n_features as usize) < K {
30+
return Err(anyhow::anyhow!(
31+
"The data must have at least K features in order to be used for KNN calculation"
32+
));
33+
}
34+
35+
let mut kdtree: kiddo::KdTree<T, K> = kiddo::KdTree::new();
36+
37+
for i in 0..n_samples {
38+
let mut point_array = [T::zero(); K];
39+
for j in 0..K {
40+
point_array[j] = *data.get([i as usize, j]).unwrap_or(&T::zero());
41+
}
42+
kdtree.add(&point_array, i);
43+
}
44+
45+
let mut knn_indices = Vec::with_capacity(n_samples as usize);
46+
let mut knn_distances_sq = Vec::with_capacity(n_samples as usize);
47+
48+
for i in 0..n_samples {
49+
let mut query_array = [T::zero(); K];
50+
for j in 0..K {
51+
query_array[j] = *data.get([i as usize, j]).unwrap_or(&T::zero());
52+
}
53+
54+
let neighbors = kdtree.nearest_n::<D>(&query_array, (k + 1) as usize);
55+
let mut indices = Vec::with_capacity(k as usize + 1);
56+
let mut distances_sq = Vec::with_capacity(k as usize + 1);
57+
58+
for neighbor in neighbors.iter() {
59+
indices.push(neighbor.item as usize);
60+
distances_sq.push(neighbor.distance);
61+
}
62+
63+
knn_indices.push(indices);
64+
knn_distances_sq.push(distances_sq);
65+
}
66+
67+
let mut distance_triplets = Vec::new();
68+
69+
for i in 0..n_samples as usize {
70+
for (idx, &j) in knn_indices[i].iter().enumerate() {
71+
distance_triplets.push((i, j, knn_distances_sq[i][idx]));
72+
}
73+
}
74+
75+
let mut sigmas_sq = Vec::with_capacity(n_samples as usize);
76+
77+
for i in 0..n_samples as usize {
78+
let mut dist_wo_self: Vec<T> = knn_distances_sq[i]
79+
.iter()
80+
.filter(|&&d| d > T::zero())
81+
.copied()
82+
.collect();
83+
84+
let sigma = if dist_wo_self.is_empty() {
85+
T::one()
86+
} else {
87+
dist_wo_self.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
88+
let median_idx = dist_wo_self.len() / 2;
89+
dist_wo_self[median_idx]
90+
};
91+
sigmas_sq.push(sigma);
92+
}
93+
94+
let mut connectivity_triplets = Vec::new();
95+
let min_weight = T::from_f64(1e-14).unwrap();
96+
97+
for i in 0..n_samples as usize {
98+
for &j in knn_indices[i].iter().skip(1) {
99+
if i <= j {
100+
// place here upper triangle restriction
101+
let dist_sq = if let Some(pos) = knn_indices[i].iter().position(|&x| x == j) {
102+
knn_distances_sq[i][pos]
103+
} else {
104+
continue;
105+
};
106+
107+
let sigma_i_sq = sigmas_sq[i];
108+
let sigma_j_sq = sigmas_sq[j];
109+
let sigma_i = sigma_i_sq.sqrt();
110+
let sigma_j = sigma_j_sq.sqrt();
111+
let num = T::from(2).unwrap() * sigma_i * sigma_j;
112+
let den = sigma_i_sq + sigma_j_sq;
113+
114+
let weight = if den > T::zero() {
115+
let normalization = (num / den).sqrt();
116+
let exponential = (-dist_sq / den).exp();
117+
normalization * exponential
118+
} else {
119+
T::zero()
120+
};
121+
122+
if weight > min_weight {
123+
connectivity_triplets.push((i, j, weight));
124+
if i != j {
125+
connectivity_triplets.push((j, i, weight)); // symmetry with just one computation step
126+
}
127+
}
128+
}
129+
}
130+
}
131+
132+
let distances_coo = CooMatrix::try_from_triplets(
133+
n_samples as usize,
134+
n_samples as usize,
135+
distance_triplets.iter().map(|&(i, _, _)| i).collect(),
136+
distance_triplets.iter().map(|&(_, j, _)| j).collect(),
137+
distance_triplets.iter().map(|&(_, _, v)| v).collect(),
138+
)
139+
.map_err(|e| anyhow::anyhow!("Failed to create distance COO matrix: {}", e))?;
140+
141+
let connectivities_coo = CooMatrix::try_from_triplets(
142+
n_samples as usize,
143+
n_samples as usize,
144+
connectivity_triplets.iter().map(|&(i, _, _)| i).collect(),
145+
connectivity_triplets.iter().map(|&(_, j, _)| j).collect(),
146+
connectivity_triplets.iter().map(|&(_, _, v)| v).collect(),
147+
)
148+
.map_err(|e| anyhow::anyhow!("Failed to create connectivity COO matrix: {}", e))?;
149+
150+
Ok(NeighborResult {
151+
distances: CsrMatrix::from(&distances_coo),
152+
connectivities: CsrMatrix::from(&connectivities_coo),
153+
})
154+
}

src/network/grouping.rs

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ pub trait NetworkGrouping: Debug + Send + Sync {
5050
}
5151

5252
#[derive(Debug, Clone)]
53+
#[derive(Default)]
5354
pub struct VectorGrouping {
5455
assignments: Vec<usize>,
5556
group_count: usize,
@@ -58,17 +59,6 @@ pub struct VectorGrouping {
5859
needs_size_update: bool,
5960
}
6061

61-
impl Default for VectorGrouping {
62-
fn default() -> Self {
63-
Self {
64-
assignments: Vec::new(),
65-
group_count: 0,
66-
group_sizes: Vec::new(),
67-
needs_size_update: false,
68-
}
69-
}
70-
}
71-
7262
impl VectorGrouping {
7363
#[inline]
7464
fn update_group_sizes(&mut self) {

src/network/mod.rs

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
1-
use crate::neighborhood::connectivity::GaussianConnectivity;
21
use crate::network::grouping::NetworkGrouping;
32
use nalgebra_sparse::{CooMatrix, CsrMatrix};
4-
use petgraph::data::DataMap;
53
use petgraph::graph::{Edges, UnGraph};
64
use petgraph::prelude::{EdgeRef, NodeIndex};
75
use rayon::iter::ParallelIterator;
@@ -23,7 +21,7 @@ pub struct NeighborAndWeightIterator<'a, N: 'a, E: 'a> {
2321
_phantom: std::marker::PhantomData<&'a N>,
2422
}
2523

26-
impl<'a, N, E> Iterator for NeighborAndWeightIterator<'a, N, E>
24+
impl<N, E> Iterator for NeighborAndWeightIterator<'_, N, E>
2725
where
2826
E: Copy,
2927
{
@@ -311,18 +309,3 @@ where
311309

312310
graph
313311
}
314-
315-
pub fn network_from_gaussian_connectivity<T>(
316-
distances: &CsrMatrix<T>,
317-
node_weights: Vec<T>,
318-
n_neighbors: usize,
319-
knn: bool,
320-
) -> Network<T, T>
321-
where
322-
T: FloatOpsTS + 'static,
323-
{
324-
let gauss_conn = GaussianConnectivity::new(knn);
325-
let connectivity_matrix = gauss_conn.compute_connectivities(distances, n_neighbors);
326-
let graph = csr_to_petgraph(connectivity_matrix, node_weights);
327-
Network::new_from_graph(graph)
328-
}

0 commit comments

Comments
 (0)