Skip to content

Commit 5706d1e

Browse files
author
Ian
committed
Fixed KNN implementation to accurately reflect scanpy's behavior
1 parent a7c505f commit 5706d1e

12 files changed

Lines changed: 193 additions & 199 deletions

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
/target
22
.idea
3-
.idea/
3+
.idea/
4+
.fleet

.vscode/settings.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
{
2+
"diffEditor.maxFileSize": 250,
3+
"rust-analyzer.hover.maxSubstitutionLength": 20,
4+
"rust-analyzer.numThreads": 64,
5+
"rust-analyzer.cachePriming.numThreads": 16,
6+
"rust-analyzer.check.command": "clippy",
7+
"rust-analyzer.lens.references.adt.enable": true,
8+
"rust-analyzer.cargo.features": [
9+
]
10+
}

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ description = "A high-performance network clustering library implementing commun
1212

1313
[dependencies]
1414
anyhow = "1.0.98"
15-
kiddo = "5.0.3"
15+
kiddo = {version = "5.0.3" }
1616
nalgebra-sparse = "0.10.0"
1717
ndarray = {version = "0.16.1" , features = ["rayon"]}
1818
num-traits = "0.2.19"

src/community_search/leiden.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ where
1919

2020
impl<T> Leiden<T>
2121
where
22-
T: FloatOpsTS,
22+
T: FloatOpsTS + 'static,
2323
{
2424
pub fn new(resolution: T, randomness: T, seed: Option<u64>) -> Self {
2525
let seed = seed.unwrap_or_default();

src/community_search/louvain.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ where
1717

1818
impl<T> Louvain<T>
1919
where
20-
T: FloatOpsTS,
20+
T: FloatOpsTS + 'static,
2121
{
2222
pub fn new(resolution: T, seed: Option<u64>) -> Self {
2323
let seed = seed.unwrap_or_default();

src/moving/fast.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ where
2424

2525
impl<T> FastLocalMoving<T>
2626
where
27-
T: FloatOpsTS,
27+
T: FloatOpsTS + 'static,
2828
{
2929
pub fn new(resolution: T) -> Self {
3030
FastLocalMoving {

src/moving/merging.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ where
2121

2222
impl<T> LocalMerging<T>
2323
where
24-
T: FloatOpsTS,
24+
T: FloatOpsTS + 'static,
2525
{
2626
pub fn new(resolution: T, randomness: T) -> Self {
2727
LocalMerging {

src/moving/standard.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ where
1919

2020
impl<T> StandardLocalMoving<T>
2121
where
22-
T: FloatOpsTS {
22+
T: FloatOpsTS + 'static {
2323

2424
pub fn new(resolution: T) -> Self {
2525
StandardLocalMoving {

src/neighborhood/connectivity.rs

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ where
1414

1515
impl<T> GaussianConnectivity<T>
1616
where
17-
T: FloatOpsTS,
17+
T: FloatOpsTS + 'static,
1818
{
1919
pub fn new(knn: bool) -> Self {
2020
Self {
@@ -40,15 +40,19 @@ where
4040
distances: &CsrMatrix<T>,
4141
n_neighbors: usize,
4242
) -> (Array2<usize>, Array2<T>) {
43-
4443
let n_obs = distances.nrows();
4544
let mut knn_indices = Array2::<usize>::zeros((n_obs, n_neighbors));
4645
let mut knn_distances = Array2::zeros((n_obs, n_neighbors));
4746

4847
for i in 0..n_obs {
4948
let mut neighbors: Vec<(usize, T)> = Vec::new();
5049

51-
for (col, &dist) in distances.row(i).col_indices().iter().zip(distances.row(i).values().iter()) {
50+
for (col, &dist) in distances
51+
.row(i)
52+
.col_indices()
53+
.iter()
54+
.zip(distances.row(i).values().iter())
55+
{
5256
if *col != i || dist > T::zero() {
5357
neighbors.push((*col, dist));
5458
}
@@ -72,7 +76,8 @@ where
7276

7377
for i in 0..n_obs {
7478
let sigma_sq = if self.knn {
75-
let mut distances_sq: Vec<T> = knn_distances.row(i)
79+
let mut distances_sq: Vec<T> = knn_distances
80+
.row(i)
7681
.iter()
7782
.filter(|&&d| d > T::zero())
7883
.map(|&d| d * d)
@@ -81,7 +86,8 @@ where
8186
if distances_sq.is_empty() {
8287
T::from_f64(1.0).unwrap()
8388
} else {
84-
distances_sq.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
89+
distances_sq
90+
.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
8591
let median_idx = distances_sq.len() / 2;
8692
distances_sq[median_idx]
8793
}
@@ -113,15 +119,18 @@ where
113119
for j in 0..n_neighbors {
114120
let neighbor_idx = knn_indices[[i, j]];
115121

116-
let pair = if i < neighbor_idx { (i, neighbor_idx) } else { (neighbor_idx, i) };
122+
let pair = if i < neighbor_idx {
123+
(i, neighbor_idx)
124+
} else {
125+
(neighbor_idx, i)
126+
};
117127
if processed_pairs.contains(&pair) {
118128
continue;
119129
}
120130
processed_pairs.insert(pair);
121131

122132
if let Some(dist_sq) = distances.get_entry(i, neighbor_idx) {
123133
let dist_sq = dist_sq.into_value();
124-
let dist_sq = dist_sq * dist_sq;
125134
let weight = self.compute_gaussian_weight(i, neighbor_idx, dist_sq, sigmas);
126135

127136
if weight > self.min_weight_threshold {
@@ -136,11 +145,11 @@ where
136145
} else {
137146
// For dense: compute all pairwise weights above threshold
138147
for i in 0..n_obs {
139-
for j in i..n_obs { // Only upper triangle, then make symmetric
148+
for j in i..n_obs {
149+
// Only upper triangle, then make symmetric
140150
if let Some(dist) = distances.get_entry(i, j) {
141151
let dist = dist.into_value();
142-
let dist_sq = dist * dist;
143-
let weight = self.compute_gaussian_weight(i, j, dist_sq, sigmas);
152+
let weight = self.compute_gaussian_weight(i, j, dist, sigmas);
144153

145154
if weight > self.min_weight_threshold {
146155
triplets.push((i, j, weight));
@@ -154,18 +163,13 @@ where
154163
}
155164

156165
// Convert to CSR matrix
157-
let mut rows = Vec::new();
158-
let mut cols = Vec::new();
159-
let mut data = Vec::new();
160-
161-
for (row, col, val) in triplets {
162-
rows.push(row);
163-
cols.push(col);
164-
data.push(val);
165-
}
166+
let rows: Vec<usize> = triplets.iter().map(|(r, _, _)| *r).collect();
167+
let cols: Vec<usize> = triplets.iter().map(|(_, c, _)| *c).collect();
168+
let data: Vec<T> = triplets.iter().map(|(_, _, v)| *v).collect();
166169

167-
CsrMatrix::try_from_csr_data(n_obs, n_obs, rows, cols, data)
168-
.expect("Failed to create Gaussian connectivity matrix")
170+
let coo = nalgebra_sparse::CooMatrix::try_from_triplets(n_obs, n_obs, rows, cols, data)
171+
.expect("Failed to create COO matrix");
172+
CsrMatrix::from(&coo)
169173
}
170174

171175
fn compute_gaussian_weight(&self, i: usize, j: usize, dist_sq: T, sigmas: &[T]) -> T {

0 commit comments

Comments
 (0)