@@ -8,158 +8,71 @@ use ndarray::Array2;
88use rayon:: iter:: { ParallelBridge , ParallelIterator } ;
99use single_utilities:: traits:: FloatOpsTS ;
1010use single_utilities:: types:: PathwayNetwork ;
11+ use crate :: testing:: utils:: SparseMatrixRef ;
12+ use num_traits:: AsPrimitive ;
1113
12- // Following the general implementation presented here, But adapted to nalgebra_sparse and multithreading: https://github.com/scverse/decoupler/blob/main/src/decoupler/mt/_aucell.py
14+ // ... (validate_n_up and au_cell_internal remain the same) ...
1315
14- fn validate_n_up (
15- n_var : usize ,
16+ pub fn au_cell_csr < T : FloatOpsTS > (
17+ matrix : & CsrMatrix < T > ,
18+ pathway_network : & PathwayNetwork ,
1619 n_up_abs : Option < usize > ,
1720 n_up_frac : Option < f32 > ,
18- ) -> anyhow:: Result < usize > {
19- match ( n_up_abs, n_up_frac) {
20- ( None , None ) => {
21- let mut nup = ( n_var as f32 * 0.05 ) . ceil ( ) as usize ;
22- nup = nup. max ( n_var) . min ( 2 ) ;
23- Ok ( nup)
24- }
25- ( None , Some ( x) ) => {
26- let frac = ( x * n_var as f32 ) . ceil ( ) as usize ;
27- Ok ( frac. max ( n_var) . min ( 2 ) )
28- }
29- ( Some ( x) , None ) => Ok ( x. max ( n_var) . min ( 2 ) ) ,
30- ( Some ( _) , Some ( _) ) => Err ( anyhow ! (
31- "Cannot define both, n_up_abs AND n_up_frac, only one of them can be defined."
32- ) ) ,
33- }
34- }
35-
36- fn au_cell_internal (
37- all_values : Vec < ( usize , f32 ) > ,
38- pathway_network : & PathwayNetwork ,
39- n_up : usize ,
40- n_src : usize ,
41- ) -> anyhow:: Result < Vec < f32 > > {
42- let mut rank_map: HashMap < usize , usize > = HashMap :: new ( ) ;
43- for ( rank, ( idx, _) ) in all_values. iter ( ) . enumerate ( ) {
44- rank_map. insert ( * idx, rank + 1 ) ;
45- }
46-
47- // temporarily no paralellization here to prevent nesting...
48- let mut v: Vec < ( usize , f32 ) > = ( 0 ..n_src)
49- . map ( |j| {
50- // dont know if we should actually parallelize here!
51- let functional_set = pathway_network. get_pathway_features ( j) ;
52-
53- let x_th = 1 ..=functional_set. len ( ) ;
54- let x_th: Vec < usize > = x_th. filter ( |& v| v < n_up) . collect ( ) ;
55-
56- let max_auc: f32 = x_th
57- . iter ( )
58- . enumerate ( )
59- . map ( |( i, & val) | {
60- let next = if i < x_th. len ( ) - 1 {
61- x_th[ i + 1 ] as f32
62- } else {
63- n_up as f32
64- } ;
65- ( next - val as f32 ) * val as f32
66- } )
67- . sum ( ) ;
68-
69- let mut x: Vec < usize > = functional_set
70- . iter ( )
71- . filter_map ( |& idx| rank_map. get ( & idx) . copied ( ) )
72- . collect ( ) ;
73-
74- x. sort_unstable ( ) ;
75- x. retain ( |& rank| rank <= n_up) ;
76-
77- let y: Vec < f32 > = ( 1 ..=x. len ( ) ) . map ( |i| i as f32 ) . collect ( ) ;
78-
79- let mut x_f32: Vec < f32 > = x. iter ( ) . map ( |& r| r as f32 ) . collect ( ) ;
80-
81- x_f32. push ( n_up as f32 ) ;
82-
83- let auc: f32 = x_f32
84- . windows ( 2 )
85- . zip ( y. iter ( ) )
86- . map ( |( window, & y_val) | ( window[ 1 ] - window[ 0 ] ) * y_val)
87- . sum ( ) ;
88- let enrich_v = if max_auc > 0.0 { auc / max_auc } else { 0.0 } ;
89- ( j, enrich_v)
90- } )
91- . collect ( ) ;
92-
93- v. sort_unstable_by ( |& a, b| a. 0 . cmp ( & b. 0 ) ) ;
94- let v: Vec < f32 > = v. iter ( ) . map ( |a| a. 1 ) . collect ( ) ;
95-
96- Ok ( v)
97- }
98-
99- fn au_cell_csr_row < T : FloatOpsTS > (
100- lane : CsrRow < T > ,
101- pathway_network : & PathwayNetwork ,
102- n_up : usize ,
103- n_src : usize ,
104- ) -> anyhow:: Result < Vec < f32 > > {
105- let mut all_values: Vec < ( usize , f32 ) > = lane
106- . col_indices ( )
107- . iter ( )
108- . zip ( lane. values ( ) . iter ( ) )
109- . map ( |( & idx, val) | ( idx, val. to_f32 ( ) . unwrap ( ) ) )
110- . collect ( ) ;
111-
112- all_values. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
113-
114- au_cell_internal ( all_values, pathway_network, n_up, n_src)
115- }
116-
117- fn au_cell_csc_row < T : FloatOpsTS > (
118- lane : CscCol < T > ,
119- pathway_network : & PathwayNetwork ,
120- n_up : usize ,
121- n_src : usize ,
122- ) -> anyhow:: Result < Vec < f32 > > {
123- let mut all_values: Vec < ( usize , f32 ) > = lane
124- . row_indices ( )
125- . iter ( )
126- . zip ( lane. values ( ) . iter ( ) )
127- . map ( |( & idx, val) | ( idx, val. to_f32 ( ) . unwrap ( ) ) )
128- . collect ( ) ;
129-
130- all_values. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
131-
132- au_cell_internal ( all_values, pathway_network, n_up, n_src)
21+ verbose : bool ,
22+ ) -> anyhow:: Result < Array2 < f32 > > {
23+ let smr = SparseMatrixRef {
24+ maj_ind : matrix. row_offsets ( ) ,
25+ min_ind : matrix. col_indices ( ) ,
26+ val : matrix. values ( ) ,
27+ n_rows : matrix. nrows ( ) ,
28+ n_cols : matrix. ncols ( ) ,
29+ } ;
30+ au_cell_sparse ( smr, pathway_network, n_up_abs, n_up_frac, verbose)
13331}
13432
135- pub fn au_cell_csr < T : FloatOpsTS > (
136- matrix : & CsrMatrix < T > ,
33+ pub fn au_cell_sparse < T , N , I > (
34+ matrix : SparseMatrixRef < T , N , I > ,
13735 pathway_network : & PathwayNetwork ,
13836 n_up_abs : Option < usize > ,
13937 n_up_frac : Option < f32 > ,
14038 verbose : bool ,
141- ) -> anyhow:: Result < Array2 < f32 > > {
142- let ( n_obs, n_vars) = ( matrix. nrows ( ) , matrix. ncols ( ) ) ;
39+ ) -> anyhow:: Result < Array2 < f32 > >
40+ where
41+ T : FloatOpsTS ,
42+ N : AsPrimitive < usize > + Send + Sync ,
43+ I : AsPrimitive < usize > + Send + Sync ,
44+ {
45+ let ( n_obs, n_vars) = ( matrix. n_rows , matrix. n_cols ) ;
14346 let n_src = pathway_network. get_num_pathways ( ) ;
14447 let n_up = validate_n_up ( n_vars, n_up_abs, n_up_frac) ?;
14548
14649 let res: anyhow:: Result < Vec < ( usize , Vec < f32 > ) > > = match verbose {
147- true => matrix
148- . row_iter ( )
149- . enumerate ( )
150- . par_bridge ( )
50+ true => ( 0 ..n_obs)
51+ . into_par_iter ( )
15152 . progress_count ( n_obs as u64 )
152- . map ( |( i, r) | {
153- let re = au_cell_csr_row ( r, pathway_network, n_up, n_src) ?;
53+ . map ( |i| {
54+ let ( cols, vals) = matrix. get_major ( i) ;
55+ let mut all_values: Vec < ( usize , f32 ) > = cols
56+ . iter ( )
57+ . zip ( vals. iter ( ) )
58+ . map ( |( & idx, val) | ( idx. as_ ( ) , val. to_f32 ( ) . unwrap ( ) ) )
59+ . collect ( ) ;
60+ all_values. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
61+ let re = au_cell_internal ( all_values, pathway_network, n_up, n_src) ?;
15462 Ok ( ( i, re) )
15563 } )
15664 . collect ( ) ,
157- false => matrix
158- . row_iter ( )
159- . enumerate ( )
160- . par_bridge ( )
161- . map ( |( i, r) | {
162- let re = au_cell_csr_row ( r, pathway_network, n_up, n_src) ?;
65+ false => ( 0 ..n_obs)
66+ . into_par_iter ( )
67+ . map ( |i| {
68+ let ( cols, vals) = matrix. get_major ( i) ;
69+ let mut all_values: Vec < ( usize , f32 ) > = cols
70+ . iter ( )
71+ . zip ( vals. iter ( ) )
72+ . map ( |( & idx, val) | ( idx. as_ ( ) , val. to_f32 ( ) . unwrap ( ) ) )
73+ . collect ( ) ;
74+ all_values. sort_by ( |a, b| b. 1 . partial_cmp ( & a. 1 ) . unwrap_or ( std:: cmp:: Ordering :: Equal ) ) ;
75+ let re = au_cell_internal ( all_values, pathway_network, n_up, n_src) ?;
16376 Ok ( ( i, re) )
16477 } )
16578 . collect ( ) ,
@@ -168,48 +81,24 @@ pub fn au_cell_csr<T: FloatOpsTS>(
16881 let mut res = res?;
16982 res. sort_unstable_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
17083
171- let res : Vec < f32 > = res. into_iter ( ) . flat_map ( |( _, v) | v) . collect ( ) ;
172- let array = Array2 :: from_shape_vec ( ( n_obs, n_vars ) , res ) ?;
84+ let res_vec : Vec < f32 > = res. into_iter ( ) . flat_map ( |( _, v) | v) . collect ( ) ;
85+ let array = Array2 :: from_shape_vec ( ( n_obs, n_src ) , res_vec ) ?;
17386 Ok ( array)
17487}
17588
17689pub fn au_cell_csc < T : FloatOpsTS > (
177- matrix : CscMatrix < T > ,
90+ matrix : & CscMatrix < T > ,
17891 pathway_network : & PathwayNetwork ,
17992 n_up_abs : Option < usize > ,
18093 n_up_frac : Option < f32 > ,
18194 verbose : bool ,
18295) -> anyhow:: Result < Array2 < f32 > > {
183- let ( n_obs, n_vars) = ( matrix. ncols ( ) , matrix. nrows ( ) ) ;
184- let n_src = pathway_network. get_num_pathways ( ) ;
185- let n_up = validate_n_up ( n_vars, n_up_abs, n_up_frac) ?;
186-
187- let res: anyhow:: Result < Vec < ( usize , Vec < f32 > ) > > = match verbose {
188- true => matrix
189- . col_iter ( )
190- . enumerate ( )
191- . par_bridge ( )
192- . progress_count ( n_obs as u64 )
193- . map ( |( i, r) | {
194- let re = au_cell_csc_row ( r, pathway_network, n_up, n_src) ?;
195- Ok ( ( i, re) )
196- } )
197- . collect ( ) ,
198- false => matrix
199- . col_iter ( )
200- . enumerate ( )
201- . par_bridge ( )
202- . map ( |( i, r) | {
203- let re = au_cell_csc_row ( r, pathway_network, n_up, n_src) ?;
204- Ok ( ( i, re) )
205- } )
206- . collect ( ) ,
96+ let smr = SparseMatrixRef {
97+ maj_ind : matrix. col_offsets ( ) ,
98+ min_ind : matrix. row_indices ( ) ,
99+ val : matrix. values ( ) ,
100+ n_rows : matrix. ncols ( ) ,
101+ n_cols : matrix. nrows ( ) ,
207102 } ;
208-
209- let mut res = res?;
210- res. sort_unstable_by ( |a, b| a. 0 . cmp ( & b. 0 ) ) ;
211-
212- let res: Vec < f32 > = res. into_iter ( ) . flat_map ( |( _, v) | v) . collect ( ) ;
213- let array = Array2 :: from_shape_vec ( ( n_obs, n_vars) , res) ?;
214- Ok ( array)
103+ au_cell_sparse ( smr, pathway_network, n_up_abs, n_up_frac, verbose)
215104}
0 commit comments