From 4fc0051cd803b7eb5a18eb74490046baf03485bf Mon Sep 17 00:00:00 2001 From: KenyaOtsuka Date: Thu, 30 Oct 2025 14:57:13 +0900 Subject: [PATCH 1/3] [feature] add solver option --- pyproject.toml | 1 + src/fastl2lir/fastl2lir.py | 25 +++++++++++++++++++------ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a33446..b9b375e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ license = { file = "LICENSE" } requires-python = ">=3.1" dependencies = [ "numpy>=1.16.6", + "scipy>=1.2.3", "threadpoolctl>=2.1.0 ; python_full_version >= '3.5'", "tqdm>=4.64.1", ] diff --git a/src/fastl2lir/fastl2lir.py b/src/fastl2lir/fastl2lir.py index c8c60a2..3fb83e8 100644 --- a/src/fastl2lir/fastl2lir.py +++ b/src/fastl2lir/fastl2lir.py @@ -8,6 +8,7 @@ import numpy as np from tqdm import tqdm +from scipy import linalg as sp_linalg pv = sys.version_info @@ -18,10 +19,22 @@ class FastL2LiR(object): '''Fast L2-regularized linear regression class.''' - def __init__(self, W=np.array([]), b=np.array([]), verbose=False): + def __init__(self, W=np.array([]), b=np.array([]), verbose=False, solver='scipy'): self.__W = W self.__b = b self.__verbose = verbose + self.__solver = solver + + # Choose linear solver once to avoid branching at call sites + if solver == 'scipy': + def _solve(a, b): + return sp_linalg.solve(a, b, assume_a='sym', check_finite=False) + elif solver == 'numpy': + def _solve(a, b): + return np.linalg.solve(a, b) + else: + raise ValueError('Unknown solver: %s' % solver) + self.__solve = _solve @property def W(self): @@ -245,10 +258,10 @@ def __sub_fit(self, X, Y, alpha=0, n_feat=0, use_all_features=True, dtype=np.flo # Choose the more efficient method based on matrix dimensions if X.shape[0] > X.shape[1]: # Use primal form for tall matrices (more samples than features) - Wb = np.linalg.solve(np.matmul(X.T, X) + alpha * np.eye(X.shape[1], dtype=dtype), np.matmul(X.T, Y)) + Wb = self.__solve(np.matmul(X.T, X) + alpha * np.eye(X.shape[1], dtype=dtype), np.matmul(X.T, Y)) else: # Use dual form for wide matrices (more features than samples) - Wb = np.matmul(X.T, np.linalg.solve(np.matmul(X, X.T) + alpha * np.eye(X.shape[0], dtype=dtype), Y)) + Wb = np.matmul(X.T, self.__solve(np.matmul(X, X.T) + alpha * np.eye(X.shape[0], dtype=dtype), Y)) W = Wb[0:-1, :] b = Wb[-1, :][np.newaxis, :] # Returning b as a 2D array @@ -274,7 +287,7 @@ def __sub_fit(self, X, Y, alpha=0, n_feat=0, use_all_features=True, dtype=np.flo I = I[0:n_feat] I = np.hstack((I, X.shape[1]-1)) W0_sub = (W0.ravel()[(I + (I * W0.shape[1]).reshape((-1, 1))).ravel()]).reshape(I.size, I.size) - Wb = np.linalg.solve(W0_sub, W1[index_outputDim][I].reshape(-1, 1)) + Wb = self.__solve(W0_sub, W1[index_outputDim][I].reshape(-1, 1)) for index_selectedDim in range(n_feat): W[index_outputDim, I[index_selectedDim]] = Wb[index_selectedDim] b[0, index_outputDim] = Wb[-1] @@ -287,7 +300,7 @@ def __sub_fit(self, X, Y, alpha=0, n_feat=0, use_all_features=True, dtype=np.flo I = I[0:n_feat] I = np.hstack((I, X.shape[1]-1)) W0_sub = (W0.ravel()[(I + (I * W0.shape[1]).reshape((-1,1))).ravel()]).reshape(I.size, I.size) - Wb = np.linalg.solve(W0_sub, W1[index_outputDim][I].reshape(-1,1)) + Wb = self.__solve(W0_sub, W1[index_outputDim][I].reshape(-1,1)) for index_selectedDim in range(n_feat): W[index_outputDim, I[index_selectedDim]] = Wb[index_selectedDim] b[0, index_outputDim] = Wb[-1] @@ -344,7 +357,7 @@ def __sub_fit_save_select_feat( newX = np.hstack((newX, np.ones((newX.shape[0], 1), dtype=dtype))) # Add one column to rightmost column W0 = np.matmul(newX.T, newX) + alpha * np.eye(newX.shape[1], dtype=dtype) W1 = np.matmul(selY.ravel(), newX).reshape(-1,1) - Wb = np.linalg.solve(W0, W1) + Wb = self.__solve(W0, W1) for index_selectedDim in range(n_feat): W[index_outputDim, I[index_selectedDim]] = Wb[index_selectedDim] b[0, index_outputDim] = Wb[-1] From 959824fedac39068da29850b858834f55a75955b Mon Sep 17 00:00:00 2001 From: KenyaOtsuka Date: Thu, 21 May 2026 16:02:22 +0900 Subject: [PATCH 2/3] [test] add tests for numpy solver --- tests/test_fastl2lir.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_fastl2lir.py b/tests/test_fastl2lir.py index 11bd9a2..f28f324 100644 --- a/tests/test_fastl2lir.py +++ b/tests/test_fastl2lir.py @@ -139,6 +139,24 @@ def test_chunk(self): np.testing.assert_array_almost_equal(yp_2d, data['yp_2d']) + def test_solver_numpy_matches_scipy(self): + '''numpy solver produces same result as scipy solver (default).''' + data = np.load('./tests/testdata_basic.npz') + + model_scipy = fastl2lir.FastL2LiR(solver='scipy') + model_numpy = fastl2lir.FastL2LiR(solver='numpy') + + model_scipy.fit(data['x_tr'], data['y_2d']) + model_numpy.fit(data['x_tr'], data['y_2d']) + + np.testing.assert_array_almost_equal(model_numpy.W, model_scipy.W) + np.testing.assert_array_almost_equal(model_numpy.b, model_scipy.b) + + def test_solver_invalid(self): + '''Invalid solver name raises ValueError.''' + with self.assertRaises(ValueError): + fastl2lir.FastL2LiR(solver='cholesky') + def test_reshape(self): '''Test for reshaping.''' Y_shape = (200, 10, 10, 5) From 515f98fd5eb47488330ba0e1c7c74dd493c138d4 Mon Sep 17 00:00:00 2001 From: KenyaOtsuka Date: Thu, 21 May 2026 17:42:49 +0900 Subject: [PATCH 3/3] [misc] try positive-definite solver before symmetric fallback --- src/fastl2lir/fastl2lir.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/fastl2lir/fastl2lir.py b/src/fastl2lir/fastl2lir.py index 3fb83e8..1a9832b 100644 --- a/src/fastl2lir/fastl2lir.py +++ b/src/fastl2lir/fastl2lir.py @@ -28,7 +28,10 @@ def __init__(self, W=np.array([]), b=np.array([]), verbose=False, solver='scipy' # Choose linear solver once to avoid branching at call sites if solver == 'scipy': def _solve(a, b): - return sp_linalg.solve(a, b, assume_a='sym', check_finite=False) + try: + return sp_linalg.solve(a, b, assume_a='pos', check_finite=False) + except sp_linalg.LinAlgError: + return sp_linalg.solve(a, b, assume_a='sym', check_finite=False) elif solver == 'numpy': def _solve(a, b): return np.linalg.solve(a, b)