@@ -4,74 +4,45 @@ use anndata::{
44 ArrayElemOp ,
55} ;
66use nalgebra_sparse:: { pattern:: SparsityPattern , CsrMatrix } ;
7+ use ndarray:: Ix1 ;
78
8-
9-
10-
11-
12-
9+ use crate :: { converter:: LoadingConfig , utils:: { read_array_as_usize_optimized, read_array_slice_as_usize} } ;
1310
1411pub fn load_csr_chunked < B : Backend > (
1512 container : & DataContainer < B > ,
1613 config : & LoadingConfig ,
1714) -> anyhow:: Result < ArrayData > {
1815 let group = container. as_group ( ) ?;
1916 let shape: Vec < u64 > = group. get_attr ( "shape" ) ?;
20- let nrows = shape[ 0 ] as usize ;
21- let ncols = shape[ 1 ] as usize ;
17+ let ( nrows, ncols) = ( shape[ 0 ] as usize , shape[ 1 ] as usize ) ;
2218
2319 let data_ds = group. open_dataset ( "data" ) ?;
2420 let indices_ds = group. open_dataset ( "indices" ) ?;
2521 let indptr_ds = group. open_dataset ( "indptr" ) ?;
2622
27- // Use the helper function to read indptr
28- let indptr = read_array_as_usize :: < B > ( & indptr_ds) ?;
29-
23+ let indptr = read_array_as_usize_optimized :: < B > ( & indptr_ds) ?;
3024 let nnz = data_ds. shape ( ) [ 0 ] ;
3125
3226 if config. show_progress && nnz > 10_000_000 {
33- println ! (
34- "Loading CSR matrix: {} rows, {} cols, {} non-zeros" ,
35- nrows, ncols, nnz
36- ) ;
27+ println ! ( "Loading CSR matrix: {} rows, {} cols, {} non-zeros" , nrows, ncols, nnz) ;
3728 }
3829
30+ use ScalarType :: * ;
3931 match data_ds. dtype ( ) ? {
40- ScalarType :: F64 => load_csr_typed :: < B , f64 > (
41- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
42- ) ,
43- ScalarType :: F32 => load_csr_typed :: < B , f32 > (
44- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
45- ) ,
46- ScalarType :: I64 => load_csr_typed :: < B , i64 > (
47- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
48- ) ,
49- ScalarType :: I32 => load_csr_typed :: < B , i32 > (
50- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
51- ) ,
52- ScalarType :: I16 => load_csr_typed :: < B , i16 > (
53- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
54- ) ,
55- ScalarType :: I8 => load_csr_typed :: < B , i8 > (
56- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
57- ) ,
58- ScalarType :: U64 => load_csr_typed :: < B , u64 > (
59- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
60- ) ,
61- ScalarType :: U32 => load_csr_typed :: < B , u32 > (
62- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
63- ) ,
64- ScalarType :: U16 => load_csr_typed :: < B , u16 > (
65- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
66- ) ,
67- ScalarType :: U8 => load_csr_typed :: < B , u8 > (
68- nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config,
69- ) ,
32+ F64 => load_csr_typed :: < B , f64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
33+ F32 => load_csr_typed :: < B , f32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
34+ I64 => load_csr_typed :: < B , i64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
35+ I32 => load_csr_typed :: < B , i32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
36+ I16 => load_csr_typed :: < B , i16 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
37+ I8 => load_csr_typed :: < B , i8 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
38+ U64 => load_csr_typed :: < B , u64 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
39+ U32 => load_csr_typed :: < B , u32 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
40+ U16 => load_csr_typed :: < B , u16 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
41+ U8 => load_csr_typed :: < B , u8 > ( nrows, ncols, nnz, indptr, & data_ds, & indices_ds, config) ,
7042 dt => anyhow:: bail!( "Unsupported data type for CSR matrix: {:?}" , dt) ,
7143 }
7244}
7345
74-
7546fn load_csr_typed < B : Backend , T : anndata:: backend:: BackendData > (
7647 nrows : usize ,
7748 ncols : usize ,
@@ -82,69 +53,46 @@ fn load_csr_typed<B: Backend, T: anndata::backend::BackendData>(
8253 config : & LoadingConfig ,
8354) -> anyhow:: Result < ArrayData >
8455where
85- anndata :: ArrayData : std :: convert :: From < nalgebra_sparse :: CsrMatrix < T > >
56+ ArrayData : From < CsrMatrix < T > > ,
8657{
87- let chunk_size = ( config. chunk_size_mb * 1_048_576 ) / ( std:: mem:: size_of :: < T > ( ) + 8 ) ;
88- let chunk_size = chunk_size. max ( 1000 ) ;
58+ let chunk_size = ( ( config. chunk_size_mb << 20 ) / ( std:: mem:: size_of :: < T > ( ) + 8 ) ) . max ( 1000 ) ;
8959
9060 let mut data = Vec :: with_capacity ( nnz) ;
9161 let mut indices = Vec :: with_capacity ( nnz) ;
9262
63+ let show_progress = config. show_progress && nnz > 10_000_000 ;
64+ let progress_interval = if show_progress { nnz / 10 } else { usize:: MAX } ;
65+ let mut next_progress = progress_interval;
66+
9367 let mut offset = 0 ;
94- let mut last_progress = 0 ;
95-
9668 while offset < nnz {
9769 let chunk_end = ( offset + chunk_size) . min ( nnz) ;
98-
99- let data_array = data_ds. read_array_slice :: < T , _ , ndarray:: Ix1 > ( & [ SelectInfoElem :: from ( offset..chunk_end) ] ) ?;
100- let data_chunk: Vec < T > = data_array. into_raw_vec ( ) ;
101-
102-
103- match indices_ds. dtype ( ) ? {
104- ScalarType :: U64 => {
105- let indices_array = indices_ds. read_array_slice :: < u64 , _ , ndarray:: Ix1 > ( & [ SelectInfoElem :: from ( offset..chunk_end) ] ) ?;
106- let ( indices_u64, _) = indices_array. into_raw_vec_and_offset ( ) ;
107- indices. extend ( indices_u64. into_iter ( ) . map ( |x| x as usize ) ) ;
108- }
109- ScalarType :: U32 => {
110- let indices_array = indices_ds. read_array_slice :: < u32 , _ , ndarray:: Ix1 > ( & [ SelectInfoElem :: from ( offset..chunk_end) ] ) ?;
111- let ( indices_u32, _) = indices_array. into_raw_vec_and_offset ( ) ;
112- indices. extend ( indices_u32. into_iter ( ) . map ( |x| x as usize ) ) ;
113- }
114- ScalarType :: I64 => {
115- let indices_array = indices_ds. read_array_slice :: < i64 , _ , ndarray:: Ix1 > ( & [ SelectInfoElem :: from ( offset..chunk_end) ] ) ?;
116- let ( indices_i64, _) = indices_array. into_raw_vec_and_offset ( ) ;
117- indices. extend ( indices_i64. into_iter ( ) . map ( |x| x as usize ) ) ;
118- }
119- ScalarType :: I32 => {
120- let indices_array = indices_ds. read_array_slice :: < i32 , _ , ndarray:: Ix1 > ( & [ SelectInfoElem :: from ( offset..chunk_end) ] ) ?;
121- let ( indices_i32, _) = indices_array. into_raw_vec_and_offset ( ) ;
122- indices. extend ( indices_i32. into_iter ( ) . map ( |x| x as usize ) ) ;
123- }
124- _ => anyhow:: bail!( "Unsupported index type for CSR matrix" ) ,
70+ let range = [ SelectInfoElem :: from ( offset..chunk_end) ] ;
71+ let data_array = data_ds. read_array_slice :: < T , _ , Ix1 > ( & range) ?;
72+ let ( data_vec, data_offset) = data_array. into_raw_vec_and_offset ( ) ;
73+ if data_offset. is_none ( ) {
74+ data. extend ( data_vec) ;
75+ } else {
76+ data. extend ( data_vec) ;
12577 }
12678
127- data. extend ( data_chunk) ;
79+ let indices_chunk = read_array_slice_as_usize :: < B > ( indices_ds, & range) ?;
80+ indices. extend ( indices_chunk) ;
12881
12982 offset = chunk_end;
13083
131- if config. show_progress && nnz > 10_000_000 {
132- let progress = ( offset as f64 / nnz as f64 * 100.0 ) as usize ;
133- if progress >= last_progress + 10 {
134- println ! ( "Loading CSR matrix: {}%" , progress) ;
135- last_progress = progress;
136- }
84+ if show_progress && offset >= next_progress {
85+ println ! ( "Loading CSR matrix: {}%" , offset * 100 / nnz) ;
86+ next_progress += progress_interval;
13787 }
13888 }
13989
140- if config . show_progress && nnz > 10_000_000 {
90+ if show_progress {
14191 println ! ( "Constructing CSR matrix structure..." ) ;
14292 }
14393
144- let pattern = unsafe {
145- SparsityPattern :: from_offset_and_indices_unchecked ( nrows, ncols, indptr, indices)
146- } ;
147- let csr = CsrMatrix :: try_from_pattern_and_values ( pattern, data) . map_err ( |e| anyhow:: anyhow!( "There was an error constructing the matrix {}" , e) ) ?;
148-
149- Ok ( ArrayData :: from ( csr) )
150- }
94+ let pattern = unsafe { SparsityPattern :: from_offset_and_indices_unchecked ( nrows, ncols, indptr, indices) } ;
95+ CsrMatrix :: try_from_pattern_and_values ( pattern, data)
96+ . map ( ArrayData :: from)
97+ . map_err ( |e| anyhow:: anyhow!( "Failed to construct CSR matrix: {}" , e) )
98+ }
0 commit comments