Skip to content

Commit b8671ba

Browse files
author
Ian
committed
optimized loading
1 parent 28f3062 commit b8671ba

5 files changed

Lines changed: 369 additions & 2 deletions

File tree

src/chunked_loader.rs

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
use anndata::{
2+
backend::{AttributeOp, Backend, DataContainer, DatasetOp, GroupOp, ScalarType},
3+
data::{ArrayData, SelectInfoElem},
4+
ArrayElemOp,
5+
};
6+
use nalgebra_sparse::{pattern::SparsityPattern, CsrMatrix};
7+
8+
9+
10+
11+
12+
13+
14+
pub fn load_csr_chunked<B: Backend>(
15+
container: &DataContainer<B>,
16+
config: &LoadingConfig,
17+
) -> anyhow::Result<ArrayData> {
18+
let group = container.as_group()?;
19+
let shape: Vec<u64> = group.get_attr("shape")?;
20+
let nrows = shape[0] as usize;
21+
let ncols = shape[1] as usize;
22+
23+
let data_ds = group.open_dataset("data")?;
24+
let indices_ds = group.open_dataset("indices")?;
25+
let indptr_ds = group.open_dataset("indptr")?;
26+
27+
// Use the helper function to read indptr
28+
let indptr = read_array_as_usize::<B>(&indptr_ds)?;
29+
30+
let nnz = data_ds.shape()[0];
31+
32+
if config.show_progress && nnz > 10_000_000 {
33+
println!(
34+
"Loading CSR matrix: {} rows, {} cols, {} non-zeros",
35+
nrows, ncols, nnz
36+
);
37+
}
38+
39+
match data_ds.dtype()? {
40+
ScalarType::F64 => load_csr_typed::<B, f64>(
41+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
42+
),
43+
ScalarType::F32 => load_csr_typed::<B, f32>(
44+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
45+
),
46+
ScalarType::I64 => load_csr_typed::<B, i64>(
47+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
48+
),
49+
ScalarType::I32 => load_csr_typed::<B, i32>(
50+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
51+
),
52+
ScalarType::I16 => load_csr_typed::<B, i16>(
53+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
54+
),
55+
ScalarType::I8 => load_csr_typed::<B, i8>(
56+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
57+
),
58+
ScalarType::U64 => load_csr_typed::<B, u64>(
59+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
60+
),
61+
ScalarType::U32 => load_csr_typed::<B, u32>(
62+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
63+
),
64+
ScalarType::U16 => load_csr_typed::<B, u16>(
65+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
66+
),
67+
ScalarType::U8 => load_csr_typed::<B, u8>(
68+
nrows, ncols, nnz, indptr, &data_ds, &indices_ds, config,
69+
),
70+
dt => anyhow::bail!("Unsupported data type for CSR matrix: {:?}", dt),
71+
}
72+
}
73+
74+
75+
fn load_csr_typed<B: Backend, T: anndata::backend::BackendData>(
76+
nrows: usize,
77+
ncols: usize,
78+
nnz: usize,
79+
indptr: Vec<usize>,
80+
data_ds: &B::Dataset,
81+
indices_ds: &B::Dataset,
82+
config: &LoadingConfig,
83+
) -> anyhow::Result<ArrayData>
84+
where
85+
anndata::ArrayData: std::convert::From<nalgebra_sparse::CsrMatrix<T>>
86+
{
87+
let chunk_size = (config.chunk_size_mb * 1_048_576) / (std::mem::size_of::<T>() + 8);
88+
let chunk_size = chunk_size.max(1000);
89+
90+
let mut data = Vec::with_capacity(nnz);
91+
let mut indices = Vec::with_capacity(nnz);
92+
93+
let mut offset = 0;
94+
let mut last_progress = 0;
95+
96+
while offset < nnz {
97+
let chunk_end = (offset + chunk_size).min(nnz);
98+
99+
let data_array = data_ds.read_array_slice::<T, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
100+
let data_chunk: Vec<T> = data_array.into_raw_vec();
101+
102+
103+
match indices_ds.dtype()? {
104+
ScalarType::U64 => {
105+
let indices_array = indices_ds.read_array_slice::<u64, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
106+
let (indices_u64, _) = indices_array.into_raw_vec_and_offset();
107+
indices.extend(indices_u64.into_iter().map(|x| x as usize));
108+
}
109+
ScalarType::U32 => {
110+
let indices_array = indices_ds.read_array_slice::<u32, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
111+
let (indices_u32, _) = indices_array.into_raw_vec_and_offset();
112+
indices.extend(indices_u32.into_iter().map(|x| x as usize));
113+
}
114+
ScalarType::I64 => {
115+
let indices_array = indices_ds.read_array_slice::<i64, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
116+
let (indices_i64, _) = indices_array.into_raw_vec_and_offset();
117+
indices.extend(indices_i64.into_iter().map(|x| x as usize));
118+
}
119+
ScalarType::I32 => {
120+
let indices_array = indices_ds.read_array_slice::<i32, _, ndarray::Ix1>(&[SelectInfoElem::from(offset..chunk_end)])?;
121+
let (indices_i32, _) = indices_array.into_raw_vec_and_offset();
122+
indices.extend(indices_i32.into_iter().map(|x| x as usize));
123+
}
124+
_ => anyhow::bail!("Unsupported index type for CSR matrix"),
125+
}
126+
127+
data.extend(data_chunk);
128+
129+
offset = chunk_end;
130+
131+
if config.show_progress && nnz > 10_000_000 {
132+
let progress = (offset as f64 / nnz as f64 * 100.0) as usize;
133+
if progress >= last_progress + 10 {
134+
println!("Loading CSR matrix: {}%", progress);
135+
last_progress = progress;
136+
}
137+
}
138+
}
139+
140+
if config.show_progress && nnz > 10_000_000 {
141+
println!("Constructing CSR matrix structure...");
142+
}
143+
144+
let pattern = unsafe {
145+
SparsityPattern::from_offset_and_indices_unchecked(nrows, ncols, indptr, indices)
146+
};
147+
let csr = CsrMatrix::try_from_pattern_and_values(pattern, data).map_err(|e| anyhow::anyhow!("There was an error constructing the matrix {}", e))?;
148+
149+
Ok(ArrayData::from(csr))
150+
}

src/converter.rs

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::ops::{Deref, DerefMut};
33
use anndata::data::DataFrameIndex;
44
use anndata::{
55
AnnData, AnnDataOp, ArrayData, ArrayElemOp, AxisArrays, Backend, ElemCollection,
6-
ElemCollectionOp,
76
};
87
use anndata_hdf5::H5;
98
use anyhow::Ok;
@@ -13,6 +12,25 @@ use crate::{
1312
IMAnnData, IMArrayElement, IMElementCollection,
1413
};
1514

15+
/// Tuning knobs for how array data is loaded from the backing store.
#[derive(Clone, Debug)]
pub struct LoadingConfig {
    /// When true, large matrices are streamed chunk by chunk instead of
    /// being read in one shot.
    pub use_chunked_loading: bool,
    /// Approximate size of each read chunk, in mebibytes.
    pub chunk_size_mb: usize,
    /// Memory budget in mebibytes — presumably the size above which the
    /// chunked path is preferred; TODO confirm against callers.
    pub memory_threshold_mb: usize,
    /// Whether to print progress messages while loading.
    pub show_progress: bool,
}

impl Default for LoadingConfig {
    /// Whole-array loading, 100 MiB chunks, 1024 MiB threshold, progress on.
    fn default() -> Self {
        LoadingConfig {
            use_chunked_loading: false,
            chunk_size_mb: 100,
            memory_threshold_mb: 1024,
            show_progress: true,
        }
    }
}
33+
1634
pub fn convert_to_in_memory<B: Backend>(anndata: AnnData<B>) -> anyhow::Result<IMAnnData> {
1735
let obs_df = anndata.read_obs()?;
1836
let obs_names = anndata.obs_names();

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ mod ad;
22
mod base;
33
mod converter;
44
pub(crate) mod utils;
5+
pub(crate) mod chunked_loader;
6+
pub(crate) mod optimized_loader;
57

68
pub use ad::IMAnnData;
79
pub use ad::helpers::IMArrayElement;

src/optimized_loader.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
pub fn load_csr_optimized<B: Backend>(
2+
container: &DataContainer<B>,
3+
) -> anyhow::Result<ArrayData> {
4+
let group = container.as_group()?;
5+
let shape: Vec<u64> = group.get_attr("shape")?;
6+
let nrows = shape[0] as usize;
7+
let ncols = shape[1] as usize;
8+
9+
let data_ds = group.open_dataset("data")?;
10+
let indices_ds = group.open_dataset("indices")?;
11+
let indptr_ds = group.open_dataset("indptr")?;
12+
13+
// Use your existing function but optimize it to avoid iterator when possible
14+
let indptr = read_array_as_usize_optimized::<B>(&indptr_ds)?;
15+
let indices = read_array_as_usize_optimized::<B>(&indices_ds)?;
16+
17+
// Read data based on type - optimize to avoid copying when possible
18+
match data_ds.dtype()? {
19+
ScalarType::F64 => {
20+
let arr = data_ds.read_array::<f64, Ix1>()?;
21+
let (data, offset) = arr.into_raw_vec_and_offset();
22+
if offset.is_none() {
23+
build_csr_matrix(nrows, ncols, indptr, indices, data)
24+
} else {
25+
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
26+
}
27+
}
28+
ScalarType::F32 => {
29+
let arr = data_ds.read_array::<f32, Ix1>()?;
30+
let (data, offset) = arr.into_raw_vec_and_offset();
31+
if offset.is_none() {
32+
build_csr_matrix(nrows, ncols, indptr, indices, data)
33+
} else {
34+
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
35+
}
36+
}
37+
ScalarType::I64 => {
38+
let arr = data_ds.read_array::<i64, Ix1>()?;
39+
let (data, offset) = arr.into_raw_vec_and_offset();
40+
if offset.is_none() {
41+
build_csr_matrix(nrows, ncols, indptr, indices, data)
42+
} else {
43+
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
44+
}
45+
}
46+
ScalarType::I32 => {
47+
let arr = data_ds.read_array::<i32, Ix1>()?;
48+
let (data, offset) = arr.into_raw_vec_and_offset();
49+
if offset.is_none() {
50+
build_csr_matrix(nrows, ncols, indptr, indices, data)
51+
} else {
52+
build_csr_matrix(nrows, ncols, indptr, indices, arr.to_vec())
53+
}
54+
}
55+
_ => {
56+
// Fallback to standard loading for other types
57+
anndata::data::ArrayData::read(container)
58+
}
59+
}
60+
}

0 commit comments

Comments
 (0)