@@ -833,9 +833,12 @@ fn build_internal_impl<T: VectorRepr + Send + Sync>(
833833 Ok ( graph)
834834}
835835
836- /// Diversity prune from full reservoir: select max_degree from l_max candidates using diversity.
837- /// Candidates already have distances from HashPrune — no recomputation needed for i→candidate.
838- /// Only computes inter-candidate distances for the occlusion check.
836+ /// Final diversity prune matching DiskANN's occlude_list algorithm:
837+ /// - Iterative alpha: starts at 1.0, multiplied by min(alpha, 1.2) each round, capped at alpha
838+ /// - Accumulated occlusion factor per candidate: max(dist_to_point / dist_to_selected)
839+ /// - Resumable inner loop via last_checked positions
840+ ///
841+ /// Candidates already have distances from HashPrune — sorted by distance ascending.
839842// Called from within `build_internal_impl` which already runs inside a dedicated rayon
840843// thread pool installed by `build_internal`, so `par_iter` work executes on that pool.
841844#[ allow( clippy:: disallowed_methods) ]
@@ -848,6 +851,7 @@ fn final_prune_from_candidates<T: VectorRepr + Send + Sync>(
848851 alpha : f32 ,
849852) -> Vec < Vec < u32 > > {
850853 let dist_fn = make_dist_fn ( metric) ;
854+ let increment = alpha. min ( 1.2 ) ;
851855
852856 candidates_per_node
853857 . par_iter ( )
@@ -856,37 +860,73 @@ fn final_prune_from_candidates<T: VectorRepr + Send + Sync>(
856860 return Vec :: new ( ) ;
857861 }
858862
859- // Candidates are already sorted by distance from get_neighbors_sorted().
860- let mut selected: Vec < u32 > = Vec :: with_capacity ( max_degree) ;
863+ let nc = candidates. len ( ) ;
864+ // Two f32 buffers for on-demand T→f32 conversion. No bulk pre-load —
865+ // most candidates are pruned early, so converting all upfront is wasteful.
866+ let mut buf_sel = vec ! [ 0.0f32 ; ndims] ;
867+ let mut buf_cand = vec ! [ 0.0f32 ; ndims] ;
868+
869+ let mut occlude_factor = vec ! [ 0.0f32 ; nc] ;
870+ let mut last_checked = vec ! [ 0usize ; nc] ;
871+ let mut selected_idx: Vec < usize > = Vec :: with_capacity ( max_degree) ;
872+
873+ let mut current_alpha = 1.0f32 ;
874+ loop {
875+ for i in 0 ..nc {
876+ if selected_idx. len ( ) >= max_degree {
877+ break ;
878+ }
879+ if occlude_factor[ i] > current_alpha {
880+ continue ;
881+ }
861882
862- let mut point_sel = vec ! [ 0.0f32 ; ndims] ;
863- let mut point_cand = vec ! [ 0.0f32 ; ndims] ;
864- for & ( cand_id, cand_dist) in candidates {
865- if selected. len ( ) >= max_degree {
866- break ;
883+ let cand_id = candidates[ i] . 0 as usize ;
884+ let cand_dist = candidates[ i] . 1 ;
885+ T :: as_f32_into ( & data[ cand_id * ndims..( cand_id + 1 ) * ndims] , & mut buf_cand)
886+ . expect ( "f32 conversion" ) ;
887+
888+ let mut skip = false ;
889+ while last_checked[ i] < selected_idx. len ( ) {
890+ let sel_pos = selected_idx[ last_checked[ i] ] ;
891+ last_checked[ i] += 1 ;
892+
893+ if sel_pos >= i { continue ; }
894+
895+ let sel_id = candidates[ sel_pos] . 0 as usize ;
896+ T :: as_f32_into ( & data[ sel_id * ndims..( sel_id + 1 ) * ndims] , & mut buf_sel)
897+ . expect ( "f32 conversion" ) ;
898+ let dist_sel_cand = dist_fn. call ( & buf_sel, & buf_cand) ;
899+
900+ // Triangle inequality occlusion: ratio of distances.
901+ let ratio = if dist_sel_cand == 0.0 {
902+ f32:: MAX
903+ } else {
904+ cand_dist / dist_sel_cand
905+ } ;
906+ occlude_factor[ i] = occlude_factor[ i] . max ( ratio) ;
907+
908+ if occlude_factor[ i] > current_alpha {
909+ skip = true ;
910+ break ;
911+ }
912+ }
913+
914+ if skip || occlude_factor[ i] > current_alpha {
915+ continue ;
916+ }
917+
918+ // Accept this candidate.
919+ occlude_factor[ i] = f32:: MAX ;
920+ selected_idx. push ( i) ;
867921 }
868922
869- T :: as_f32_into (
870- & data[ cand_id as usize * ndims..( cand_id as usize + 1 ) * ndims] ,
871- & mut point_cand,
872- )
873- . expect ( "f32 conversion" ) ;
874- let is_pruned = selected. iter ( ) . any ( |& sel_id| {
875- T :: as_f32_into (
876- & data[ sel_id as usize * ndims..( sel_id as usize + 1 ) * ndims] ,
877- & mut point_sel,
878- )
879- . expect ( "f32 conversion" ) ;
880- let dist_sel_cand = dist_fn. call ( & point_sel, & point_cand) ;
881- dist_sel_cand * alpha < cand_dist
882- } ) ;
883-
884- if !is_pruned {
885- selected. push ( cand_id) ;
923+ if current_alpha >= alpha || selected_idx. len ( ) >= max_degree {
924+ break ;
886925 }
926+ current_alpha = ( current_alpha * increment) . min ( alpha) ;
887927 }
888928
889- selected
929+ selected_idx . iter ( ) . map ( | & i| candidates [ i ] . 0 ) . collect ( )
890930 } )
891931 . collect ( )
892932}
0 commit comments