Skip to content

Commit 68b957a

Browse files
Add PSI runnable. (#86)
* Add PSI runnable. The input stats table schema is the same with the output stats table of binning runnable. * Remove unnecessary print.
1 parent 8530edc commit 68b957a

2 files changed

Lines changed: 96 additions & 0 deletions

File tree

runnables/binning/psi.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
5+
def calc_psi_per_bin(
6+
expected_prob,
7+
actual_prob):
8+
FALLBACK_VALUE = 0.001
9+
expected_prob = FALLBACK_VALUE if expected_prob == 0.0 else expected_prob
10+
actual_prob = FALLBACK_VALUE if actual_prob == 0.0 else actual_prob
11+
12+
return (expected_prob - actual_prob) * np.log(expected_prob * 1.0 / actual_prob)
13+
14+
15+
def calc_psi(
16+
expected_bin_probs,
17+
actual_bin_probs):
18+
assert(len(expected_bin_probs) == len(actual_bin_probs))
19+
20+
result = 0.0
21+
for i in range(len(expected_bin_probs)):
22+
result += calc_psi_per_bin(expected_bin_probs[i], actual_bin_probs[i])
23+
24+
return result
25+
26+
27+
def get_cols_bin_probs(
28+
stats_df,
29+
bin_prob_column_name):
30+
col_bin_probs = {}
31+
for _, row in stats_df.iterrows():
32+
col_name = row['name']
33+
bin_probs = [float(item) for item in row[bin_prob_column_name].split(',')]
34+
col_bin_probs[col_name] = bin_probs
35+
36+
return col_bin_probs

runnables/psi.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import argparse
2+
import os
3+
import pandas as pd
4+
from binning.psi import calc_psi, get_cols_bin_probs
5+
from run_io.db_adapter import convertDSNToRfc1738
6+
from sqlalchemy import create_engine
7+
8+
9+
def build_argument_parser():
10+
parser = argparse.ArgumentParser(allow_abbrev=False)
11+
parser.add_argument("--dbname", type=str, required=True)
12+
parser.add_argument("--refer_stats_table", type=str, required=True)
13+
parser.add_argument("--bin_prob_column", type=str, default="bin_prob")
14+
15+
return parser
16+
17+
18+
if __name__ == "__main__":
19+
parser = build_argument_parser()
20+
args, _ = parser.parse_known_args()
21+
22+
select_input = os.getenv("SQLFLOW_TO_RUN_SELECT")
23+
output = os.getenv("SQLFLOW_TO_RUN_INTO")
24+
datasource = os.getenv("SQLFLOW_DATASOURCE")
25+
26+
url = convertDSNToRfc1738(datasource, args.dbname)
27+
engine = create_engine(url)
28+
29+
input_df = pd.read_sql(
30+
sql=select_input,
31+
con=engine)
32+
refer_stats_df = pd.read_sql_table(
33+
table_name=args.refer_stats_table,
34+
con=engine)
35+
36+
actual_cols_bin_probs = get_cols_bin_probs(input_df, args.bin_prob_column)
37+
expected_cols_bin_probs = get_cols_bin_probs(input_df, args.bin_prob_column)
38+
39+
common_column_names = set.intersection(
40+
set(actual_cols_bin_probs.keys()),
41+
set(expected_cols_bin_probs.keys()))
42+
43+
print("Calculate the PSI value for {} fields.".format(len(common_column_names)))
44+
cols_psi_data = []
45+
for column_name in common_column_names:
46+
psi_value = calc_psi(actual_cols_bin_probs[column_name], expected_cols_bin_probs[column_name])
47+
cols_psi_data.append(
48+
{
49+
"name": column_name,
50+
"psi": psi_value
51+
}
52+
)
53+
cols_psi_df = pd.DataFrame(cols_psi_data)
54+
55+
print("Persist the PSI result into the table {}".format(output))
56+
cols_psi_df.to_sql(
57+
name=output,
58+
con=engine,
59+
index=False
60+
)

0 commit comments

Comments
 (0)