Skip to content

Commit 981e2d8

Browse files
Pushing bug fix for shuffle (#381)
* Pushing bug fix for shuffle Fixing the shuffle bug for sklearn train-test split * Update ml_utils.py
1 parent 57b9586 commit 981e2d8

2 files changed

Lines changed: 2 additions & 9 deletions

File tree

medcat-v2/medcat/components/addons/meta_cat/ml_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,11 @@ def split_list_train_test(data: list, test_size: float, shuffle: bool = True
144144
Returns:
145145
tuple: The train data, and the test data.
146146
"""
147-
if shuffle:
148-
random.shuffle(data)
149-
150147
X_features = [x[:-1] for x in data]
151148
y_labels = [x[-1] for x in data]
152149

153150
X_train, X_test, y_train, y_test = train_test_split(
154-
X_features, y_labels, test_size=test_size, random_state=42)
151+
X_features, y_labels, test_size=test_size, shuffle=shuffle)
155152

156153
train_data = [x + [y] for x, y in zip(X_train, y_train)]
157154
test_data = [x + [y] for x, y in zip(X_test, y_test)]

v1/medcat/medcat/utils/meta_cat/ml_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,10 @@ def split_list_train_test(data: List, test_size: float, shuffle: bool = True) ->
132132
Returns:
133133
Tuple: The train data, and the test data.
134134
"""
135-
if shuffle:
136-
random.shuffle(data)
137-
138135
X_features = [x[:-1] for x in data]
139136
y_labels = [x[-1] for x in data]
140137

141-
X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size,
142-
random_state=42)
138+
X_train, X_test, y_train, y_test = train_test_split(X_features, y_labels, test_size=test_size, shuffle=shuffle)
143139

144140
train_data = [x + [y] for x, y in zip(X_train, y_train)]
145141
test_data = [x + [y] for x, y in zip(X_test, y_test)]

0 commit comments

Comments
 (0)