Skip to content

Commit 1d5535a

Browse files
Add data lag operation in the Runnable extract_ts_features (#84)
* Move db utils method to db_adapter.py * Adjust the import order * Add TODO comments of IO adapter in sqlflow runtime library. * Add __init__ in run_io folder * Add lag columns before rolling time series data and extract features. * Rename Rfc1783 to Rfc1738
1 parent b80970a commit 1d5535a

3 files changed

Lines changed: 24 additions & 6 deletions

File tree

runnables/extract_ts_features.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import argparse
22
import os
33
import pandas as pd
4-
from run_io.db_adapter import convertDSNToRfc1783
4+
from run_io.db_adapter import convertDSNToRfc1738
55
from sqlalchemy import create_engine
6-
from time_series_processing.ts_feature_extractor import add_features_extracted_from_ts_data
6+
from time_series_processing.ts_feature_extractor import add_features_extracted_from_ts_data, add_lag_columns
77

88

99
def build_argument_parser():
@@ -12,6 +12,7 @@ def build_argument_parser():
1212
parser.add_argument("--column_id", type=str, required=True)
1313
parser.add_argument("--column_time", type=str, required=True)
1414
parser.add_argument("--columns_value", type=str, required=True)
15+
parser.add_argument("--lag_num", type=int, default=1)
1516
parser.add_argument("--windows", type=str, required=True)
1617
parser.add_argument("--min_window", type=str, default=0)
1718
parser.add_argument("--extract_setting", type=str, default="minimal", choices=["minimal", "efficient", "comprehensive"])
@@ -29,23 +30,27 @@ def build_argument_parser():
2930
output = os.getenv("SQLFLOW_TO_RUN_INTO")
3031
datasource = os.getenv("SQLFLOW_DATASOURCE")
3132

32-
url = convertDSNToRfc1783(datasource, args.dbname)
33+
url = convertDSNToRfc1738(datasource, args.dbname)
3334
engine = create_engine(url)
3435
input = pd.read_sql(
3536
sql=select_input,
3637
con=engine)
3738

39+
df_with_lag_columns, lag_column_names = add_lag_columns(input, columns_value, args.lag_num)
40+
3841
print("Start extracting the features from the time series data.")
3942
df_with_extracted_features = add_features_extracted_from_ts_data(
40-
input,
43+
df_with_lag_columns,
4144
column_id=args.column_id,
4245
column_time=args.column_time,
43-
columns_value=columns_value,
46+
columns_value=lag_column_names,
4447
windows=windows,
4548
min_window=args.min_window,
4649
extract_setting=args.extract_setting)
4750
print("Complete the feature extraction.")
4851

52+
df_with_extracted_features = df_with_extracted_features.drop(columns=lag_column_names)
53+
4954
df_with_extracted_features.to_sql(
5055
name=output,
5156
con=engine,

runnables/run_io/db_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def parseMySQLDSN(dsn):
1515
# TODO(brightcoder01): Should we put this kind of common method
1616
# in sqlflow runtime? While writing the runnable code, users can
1717
# import the runtime library.
18-
def convertDSNToRfc1783(driver_dsn, defaultDbName):
18+
def convertDSNToRfc1738(driver_dsn, defaultDbName):
1919
driver, dsn = driver_dsn.split("://")
2020
user, passwd, host, port, database, config = parseMySQLDSN(dsn)
2121

runnables/time_series_processing/ts_feature_extractor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ def _roll_ts_and_extract_features(
5252
return extracted_features
5353

5454

55+
def add_lag_columns(
56+
input,
57+
columns_value,
58+
lag_num):
59+
lag_column_names = []
60+
for column_value in columns_value:
61+
lag_column_name = "{}_lag_{}".format(column_value, lag_num)
62+
input[lag_column_name] = input[column_value].shift(lag_num)
63+
lag_column_names.append(lag_column_name)
64+
65+
return input[lag_num:], lag_column_names
66+
67+
5568
def add_features_extracted_from_ts_data(
5669
input,
5770
column_id,

0 commit comments

Comments
 (0)