Skip to content

Commit 4d8a90e

Browse files
author
Ian
committed
some optimizations and reorganizing
1 parent b8671ba commit 4d8a90e

4 files changed

Lines changed: 74 additions & 104 deletions

File tree

src/chunked_loader.rs

Lines changed: 40 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -4,74 +4,45 @@ use anndata::{
44
ArrayElemOp,
55
};
66
use nalgebra_sparse::{pattern::SparsityPattern, CsrMatrix};
7+
use ndarray::Ix1;
78

8-
9-
10-
11-
12-
9+
use crate::{converter::LoadingConfig, utils::{read_array_as_usize_optimized, read_array_slice_as_usize}};
1310

1411
pub fn load_csr_chunked<B: Backend>(
1512
container: &DataContainer<B>,
1613
config: &LoadingConfig,
1714
) -> anyhow::Result<ArrayData> {
1815
let group = container.as_group()?;
1916
let shape: Vec<u64> = group.get_attr("shape")?;
20-
let nrows = shape[0] as usize;
21-
let ncols = shape[1] as usize;
17+
let (nrows, ncols) = (shape[0] as usize, shape[1] as usize);
2218

2319
let data_ds = group.open_dataset("data")?;
2420
let indices_ds = group.open_dataset("indices")?;
2521
let indptr_ds = group.open_dataset("indptr")?;
2622

27-
// Use the helper function to read indptr
28-
let indptr = read_array_as_usize::<B>(&indptr_ds)?;
29-
23+
let indptr = read_array_as_usize_optimized::<B>(&indptr_ds)?;
3024
let nnz = data_ds.shape()[0];
3125

3226
if config.show_progress && nnz > 10_000_000 {
33-
println!(
34-
"Loading CSR matrix: {} rows, {} cols, {} non-zeros",
35-
nrows, ncols, nnz
36-
);
27+
println!("Loading CSR matrix: {} rows, {} cols, {} non-zeros", nrows, ncols, nnz);
3728
}
3829

30+
use ScalarType::*;
3931
match data_ds.dtype()? {
40-
ScalarType::F64 => load_csr_typed::<B, f64>(
41-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
42-
),
43-
ScalarType::F32 => load_csr_typed::<B, f32>(
44-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
45-
),
46-
ScalarType::I64 => load_csr_typed::<B, i64>(
47-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
48-
),
49-
ScalarType::I32 => load_csr_typed::<B, i32>(
50-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
51-
),
52-
ScalarType::I16 => load_csr_typed::<B, i16>(
53-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
54-
),
55-
ScalarType::I8 => load_csr_typed::<B, i8>(
56-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
57-
),
58-
ScalarType::U64 => load_csr_typed::<B, u64>(
59-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
60-
),
61-
ScalarType::U32 => load_csr_typed::<B, u32>(
62-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
63-
),
64-
ScalarType::U16 => load_csr_typed::<B, u16>(
65-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
66-
),
67-
ScalarType::U8 => load_csr_typed::<B, u8>(
68-
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
69-
),
32+
F64 => load_csr_typed::<B, f64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
33+
F32 => load_csr_typed::<B, f32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
34+
I64 => load_csr_typed::<B, i64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
35+
I32 => load_csr_typed::<B, i32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
36+
I16 => load_csr_typed::<B, i16>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
37+
I8 => load_csr_typed::<B, i8>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
38+
U64 => load_csr_typed::<B, u64>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
39+
U32 => load_csr_typed::<B, u32>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
40+
U16 => load_csr_typed::<B, u16>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
41+
U8 => load_csr_typed::<B, u8>(nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config),
7042
dt => anyhow::bail!("Unsupported data type for CSR matrix: {:?}", dt),
7143
}
7244
}
7345

74-
7546
fn load_csr_typed<B: Backend, T: anndata::backend::BackendData>(
7647
nrows: usize,
7748
ncols: usize,
@@ -82,69 +53,46 @@ fn load_csr_typed<B: Backend, T: anndata::backend::BackendData>(
8253
config: &LoadingConfig,
8354
) -> anyhow::Result<ArrayData>
8455
where
85-
anndata::ArrayData: std::convert::From<nalgebra_sparse::CsrMatrix<T>>
56+
ArrayData: From<CsrMatrix<T>>,
8657
{
87-
let chunk_size = (config.chunk_size_mb * 1_048_576) / (std::mem::size_of::<T>() + 8);
88-
let chunk_size = chunk_size.max(1000);
58+
let chunk_size = ((config.chunk_size_mb << 20) / (std::mem::size_of::<T>() + 8)).max(1000);
8959

9060
let mut data = Vec::with_capacity(nnz);
9161
let mut indices = Vec::with_capacity(nnz);
9262

63+
let show_progress = config.show_progress && nnz > 10_000_000;
64+
let progress_interval = if show_progress { nnz / 10 } else { usize::MAX };
65+
let mut next_progress = progress_interval;
66+
9367
let mut offset = 0;
94-
let mut last_progress = 0;
95-
9668
while offset < nnz {
9769
let chunk_end = (offset + chunk_size).min(nnz);
98-
99-
let data_array = data_ds.read_array_slice::<T, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
100-
let data_chunk: Vec<T> = data_array.into_raw_vec();
101-
102-
103-
match indices_ds.dtype()? {
104-
ScalarType::U64 => {
105-
let indices_array = indices_ds.read_array_slice::<u64, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
106-
let (indices_u64, _) = indices_array.into_raw_vec_and_offset();
107-
indices.extend(indices_u64.into_iter().map(|x| x as usize));
108-
}
109-
ScalarType::U32 => {
110-
let indices_array = indices_ds.read_array_slice::<u32, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
111-
let (indices_u32, _) = indices_array.into_raw_vec_and_offset();
112-
indices.extend(indices_u32.into_iter().map(|x| x as usize));
113-
}
114-
ScalarType::I64 => {
115-
let indices_array = indices_ds.read_array_slice::<i64, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
116-
let (indices_i64, _) = indices_array.into_raw_vec_and_offset();
117-
indices.extend(indices_i64.into_iter().map(|x| x as usize));
118-
}
119-
ScalarType::I32 => {
120-
let indices_array = indices_ds.read_array_slice::<i32, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
121-
let (indices_i32, _) = indices_array.into_raw_vec_and_offset();
122-
indices.extend(indices_i32.into_iter().map(|x| x as usize));
123-
}
124-
_ => anyhow::bail!("Unsupported index type for CSR matrix"),
70+
let range = [SelectInfoElem::from(offset..chunk_end)];
71+
let data_array = data_ds.read_array_slice::<T, _, Ix1>(&range)?;
72+
let (data_vec, data_offset) = data_array.into_raw_vec_and_offset();
73+
if data_offset.is_none() {
74+
data.extend(data_vec);
75+
} else {
76+
data.extend(data_vec);
12577
}
12678

127-
data.extend(data_chunk);
79+
let indices_chunk = read_array_slice_as_usize::<B>(indices_ds, &range)?;
80+
indices.extend(indices_chunk);
12881

12982
offset = chunk_end;
13083

131-
if config.show_progress && nnz > 10_000_000 {
132-
let progress = (offset as f64 / nnz as f64 * 100.0) as usize;
133-
if progress >= last_progress + 10 {
134-
println!("Loading CSR matrix: {}%", progress);
135-
last_progress = progress;
136-
}
84+
if show_progress && offset >= next_progress {
85+
println!("Loading CSR matrix: {}%", offset * 100 / nnz);
86+
next_progress += progress_interval;
13787
}
13888
}
13989

140-
if config.show_progress && nnz > 10_000_000 {
90+
if show_progress {
14191
println!("Constructing CSR matrix structure...");
14292
}
14393

144-
let pattern = unsafe {
145-
SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices)
146-
};
147-
let csr = CsrMatrix::try_from_pattern_and_values(pattern, data).map_err(|e| anyhow::anyhow!("There was an error constructing the matrix {}", e))?;
148-
149-
Ok(ArrayData::from(csr))
150-
}
94+
let pattern = unsafe { SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices) };
95+
CsrMatrix::try_from_pattern_and_values(pattern, data)
96+
.map(ArrayData::from)
97+
.map_err(|e| anyhow::anyhow!("Failed to construct CSR matrix: {}", e))
98+
}

src/converter.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ops::{Deref, DerefMut};
33
use anndata::data::DataFrameIndex;
44
use anndata::{
55
AnnData, AnnDataOp, ArrayData, ArrayElemOp, AxisArrays, Backend, ElemCollection,
6+
ElemCollectionOp,
67
};
78
use anndata_hdf5::H5;
89
use anyhow::Ok;

src/optimized_loader.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
use anndata::{backend::{AttributeOp, DataContainer, DatasetOp, GroupOp, ScalarType}, ArrayData, Backend, Readable};
2+
use ndarray::Ix1;
3+
4+
use crate::utils::{build_csr_matrix, read_array_as_usize_optimized};
5+
16
pub fn load_csr_optimized<B: Backend>(
27
container: &DataContainer<B>,
38
) -> anyhow::Result<ArrayData> {
@@ -10,19 +15,17 @@ pub fn load_csr_optimized<B: Backend>(
1015
let indices_ds = group.open_dataset("indices")?;
1116
let indptr_ds = group.open_dataset("indptr")?;
1217

13-
// Use your existing function but optimize it to avoid iterator when possible
1418
let indptr = read_array_as_usize_optimized::<B>(&indptr_ds)?;
1519
let indices = read_array_as_usize_optimized::<B>(&indices_ds)?;
1620

17-
// Read data based on type - optimize to avoid copying when possible
1821
match data_ds.dtype()? {
1922
ScalarType::F64 => {
2023
let arr = data_ds.read_array::<f64, Ix1>()?;
2124
let (data, offset) = arr.into_raw_vec_and_offset();
2225
if offset.is_none() {
2326
build_csr_matrix(nrows, ncols, indptr, indices, data)
2427
} else {
25-
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
28+
build_csr_matrix(nrows, ncols, indptr, indices, data)
2629
}
2730
}
2831
ScalarType::F32 => {
@@ -31,7 +34,7 @@ pub fn load_csr_optimized<B: Backend>(
3134
if offset.is_none() {
3235
build_csr_matrix(nrows, ncols, indptr, indices, data)
3336
} else {
34-
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
37+
build_csr_matrix(nrows, ncols, indptr, indices, data)
3538
}
3639
}
3740
ScalarType::I64 => {
@@ -40,7 +43,7 @@ pub fn load_csr_optimized<B: Backend>(
4043
if offset.is_none() {
4144
build_csr_matrix(nrows, ncols, indptr, indices, data)
4245
} else {
43-
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
46+
build_csr_matrix(nrows, ncols, indptr, indices, data)
4447
}
4548
}
4649
ScalarType::I32 => {
@@ -49,12 +52,12 @@ pub fn load_csr_optimized<B: Backend>(
4952
if offset.is_none() {
5053
build_csr_matrix(nrows, ncols, indptr, indices, data)
5154
} else {
52-
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
55+
build_csr_matrix(nrows, ncols, indptr, indices, data)
5356
}
5457
}
5558
_ => {
56-
// Fallback to standard loading for other types
5759
anndata::data::ArrayData::read(container)
5860
}
5961
}
60-
}
62+
}
63+

src/utils/mod.rs

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{collections::HashMap, mem::replace};
22

3-
use anndata::{backend::{DataContainer, DatasetOp, GroupOp, ScalarType}, data::{DynCscMatrix, DynCsrMatrix, SelectInfoElem}, Backend};
3+
use anndata::{backend::{DataContainer, DatasetOp, GroupOp, ScalarType}, data::{DynCscMatrix, DynCsrMatrix, SelectInfoElem}, ArrayData, Backend};
44
use nalgebra_sparse::{pattern::SparsityPattern, CscMatrix, CsrMatrix};
55
use ndarray::Slice;
66

@@ -362,7 +362,7 @@ pub fn read_array_as_usize_optimized<B: Backend>(dataset: &B::Dataset) -> anyhow
362362
read_array_as_usize::<B>(dataset)
363363
}
364364

365-
fn read_array_as_usize<B: Backend>(dataset: &B::Dataset) -> anyhow::Result<Vec<usize>> {
365+
pub fn read_array_as_usize<B: Backend>(dataset: &B::Dataset) -> anyhow::Result<Vec<usize>> {
366366
match dataset.dtype()? {
367367
ScalarType::U64 => {
368368
let arr = dataset.read_array::<u64, ndarray::Ix1>()?;
@@ -400,7 +400,7 @@ fn read_array_as_usize<B: Backend>(dataset: &B::Dataset) -> anyhow::Result<Vec<u
400400
}
401401
}
402402

403-
fn read_array_slice_as_usize<B: Backend>(
403+
pub fn read_array_slice_as_usize<B: Backend>(
404404
dataset: &B::Dataset,
405405
selection: &[SelectInfoElem],
406406
) -> anyhow::Result<Vec<usize>> {
@@ -450,12 +450,30 @@ pub fn should_use_chunked_loading<B: Backend>(
450450
}
451451

452452
match container.encoding_type()? {
453-
anndata::backend::DataType::CsrMatrix(scalar_type) => {
453+
anndata::backend::DataType::CsrMatrix(_) => {
454454
let group = container.as_group()?;
455455
let nnz = group.open_dataset("data")?.shape()[0];
456456
let estimated_mb = (nnz * 16) / 1_048_576;
457457
Ok(estimated_mb > config.memory_threshold_mb)
458458
},
459459
_ => Ok(false)
460460
}
461+
}
462+
463+
pub fn build_csr_matrix<T>(
464+
nrows: usize,
465+
ncols: usize,
466+
indptr: Vec<usize>,
467+
indices: Vec<usize>,
468+
data: Vec<T>,
469+
) -> anyhow::Result<ArrayData>
470+
where
471+
CsrMatrix<T>: Into<ArrayData>,
472+
{
473+
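// SAFETY: the unchecked constructor skips validation, so `indptr` and `indices` must already form a valid CSR pattern for an `nrows` x `ncols` matrix; we trust data written by AnnData to satisfy this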
474+
let pattern = unsafe {
475+
SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices)
476+
};
477+
let csr = CsrMatrix::try_from_pattern_and_values(pattern, data).map_err(|e| anyhow::anyhow!("Building the CSR encountered an error, {}", e))?;
478+
Ok(csr.into())
461479
}

0 commit comments

Comments
 (0)