@@ -11,8 +11,8 @@ def build_argument_parser():
1111 parser = argparse .ArgumentParser (allow_abbrev = False )
1212 parser .add_argument ("--dbname" , type = str , required = True )
1313 parser .add_argument ("--columns" , type = str , required = True )
14- parser .add_argument ("--bin_methods " , type = str , required = False )
15- parser .add_argument ("--bin_nums " , type = str , required = False )
14+ parser .add_argument ("--bin_method " , type = str , required = False )
15+ parser .add_argument ("--bin_num " , type = str , required = False )
1616 parser .add_argument ("--bin_input_table" , type = str , required = False )
1717 parser .add_argument ("--reverse_cumsum" , type = bool , default = False )
1818 parser .add_argument ("--two_dim_bin_cols" , type = str , required = False )
@@ -24,8 +24,8 @@ def build_argument_parser():
2424 parser = build_argument_parser ()
2525 args , _ = parser .parse_known_args ()
2626 columns = args .columns .split (',' )
27- bin_methods = args .bin_methods .split (',' ) if args .bin_methods else None
28- bin_nums = [int (item ) for item in args .bin_nums .split (',' )] if args .bin_nums else None
27+ bin_method_array = args .bin_method .split (',' ) if args .bin_method else None
28+ bin_num_array = [int (item ) for item in args .bin_num .split (',' )] if args .bin_num else None
2929 two_dim_bin_cols = args .two_dim_bin_cols .split (',' ) if args .two_dim_bin_cols else None
3030
3131 select_input = os .getenv ("SQLFLOW_TO_RUN_SELECT" )
@@ -57,16 +57,26 @@ def build_argument_parser():
5757 raise ValueError ("The provided bin boundaries contains keys: {}. But they cannot cover all the \
5858 input columns: {}" .format (cols_bin_boundaries .keys (), columns ))
5959
60- print ("Ignore the bin_nums and bin_methods arguments" )
61- bin_nums = [None for i in range (len (columns ))]
62- bin_methods = [None for i in range (len (columns ))]
60+ print ("Ignore the bin_num and bin_method arguments" )
61+ bin_num_array = [None ] * len (columns )
62+ bin_method_array = [None ] * len (columns )
63+ else :
64+ if len (bin_num_array ) == 1 :
65+ bin_num_array = bin_num_array * len (columns )
66+ else :
67+ assert (len (bin_num_array ) == len (columns ))
68+
69+ if len (bin_method_array ) == 1 :
70+ bin_method_array = bin_method_array * len (columns )
71+ else :
72+ assert (len (bin_method_array ) == len (columns ))
6373
6474 print ("Calculate the statistics result for columns: {}" .format (columns ))
6575 stats_df = calc_stats (
6676 input_md ,
6777 columns ,
68- bin_methods ,
69- bin_nums ,
78+ bin_method_array ,
79+ bin_num_array ,
7080 cols_bin_boundaries ,
7181 args .reverse_cumsum )
7282
@@ -83,10 +93,10 @@ def build_argument_parser():
8393 input_md ,
8494 columns [0 ],
8595 columns [1 ],
86- bin_methods [0 ],
87- bin_methods [1 ],
88- bin_nums [0 ],
89- bin_nums [1 ],
96+ bin_method_array [0 ],
97+ bin_method_array [1 ],
98+ bin_num_array [0 ],
99+ bin_num_array [1 ],
90100 cols_bin_boundaries .get (columns [0 ], None ),
91101 cols_bin_boundaries .get (columns [1 ], None ),
92102 args .reverse_cumsum )
0 commit comments