Skip to content

Commit abdba00

Browse files
committed
Fix _assert_equal check
1 parent 52023c9 commit abdba00

1 file changed

Lines changed: 5 additions & 1 deletion

File tree

tests/test_sklearn.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,7 +1010,11 @@ def _assert_equal(original_model, loaded_model):
10101010
if check_if_fit(loaded_model):
10111011
_assert_same_state_dict(original_model.state_dict_,
10121012
loaded_model.state_dict_)
1013-
X = np.random.normal(0, 1, (100, 1))
1013+
1014+
n_features = loaded_model.n_features_
1015+
if isinstance(n_features, list):
1016+
n_features = n_features[0]
1017+
X = np.random.normal(0, 1, (100, n_features))
10141018

10151019
if loaded_model.num_sessions is not None:
10161020
assert np.allclose(loaded_model.transform(X, session_id=0),

0 commit comments

Comments
 (0)