Skip to content

Commit 07e227e

Browse files
authored
Merge pull request #17 from CompOmics/fix/linting-and-numpy-pin
Remove NumPy version pin; general linting and typing updates
2 parents d82de26 + 2c5a047 commit 07e227e

5 files changed

Lines changed: 192 additions & 189 deletions

File tree

im2deep/calibrate.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
... )
2525
"""
2626

27+
from __future__ import annotations
28+
2729
import logging
28-
from typing import Dict, Union, Optional
30+
from typing import cast
2931

3032
import numpy as np
3133
import pandas as pd
32-
from numpy import ndarray
33-
from psm_utils.peptidoform import Peptidoform
3434

3535
from im2deep._exceptions import CalibrationError
3636

@@ -40,27 +40,28 @@
4040
def _validate_calibration_inputs(
4141
cal_df: pd.DataFrame,
4242
reference_dataset: pd.DataFrame,
43-
required_cal_columns: Optional[list] = None,
44-
required_ref_columns: Optional[list] = None,
43+
required_cal_columns: list | None = None,
44+
required_ref_columns: list | None = None,
4545
) -> None:
4646
"""
4747
Validate input dataframes for calibration functions.
4848
4949
Parameters
5050
----------
51-
cal_df : pd.DataFrame
51+
cal_df
5252
Calibration dataset
53-
reference_dataset : pd.DataFrame
53+
reference_dataset
5454
Reference dataset
55-
required_cal_columns : list, optional
55+
required_cal_columns
5656
Required columns for calibration dataset
57-
required_ref_columns : list, optional
57+
required_ref_columns
5858
Required columns for reference dataset
5959
6060
Raises
6161
------
6262
CalibrationError
6363
If validation fails
64+
6465
"""
6566
if cal_df.empty:
6667
raise CalibrationError("Calibration dataset is empty")
@@ -91,11 +92,11 @@ def get_ccs_shift(
9192
9293
Parameters
9394
----------
94-
cal_df : pd.DataFrame
95+
cal_df
9596
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
96-
reference_dataset : pd.DataFrame
97+
reference_dataset
9798
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'
98-
use_charge_state : int, default 2
99+
use_charge_state
99100
Charge state to use for CCS shift calculation. Should be in range [2,4].
100101
101102
Returns
@@ -120,6 +121,7 @@ def get_ccs_shift(
120121
--------
121122
>>> shift = get_ccs_shift(calibration_df, reference_df, use_charge_state=2)
122123
>>> print(f"CCS shift factor: {shift:.2f} Ų")
124+
123125
"""
124126
# Validate inputs
125127
_validate_calibration_inputs(
@@ -187,7 +189,7 @@ def get_ccs_shift(
187189

188190
def get_ccs_shift_per_charge(
189191
cal_df: pd.DataFrame, reference_dataset: pd.DataFrame
190-
) -> Dict[int, float]:
192+
) -> dict[int, float]:
191193
"""
192194
Calculate CCS shift factors per charge state.
193195
@@ -197,9 +199,9 @@ def get_ccs_shift_per_charge(
197199
198200
Parameters
199201
----------
200-
cal_df : pd.DataFrame
202+
cal_df
201203
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
202-
reference_dataset : pd.DataFrame
204+
reference_dataset
203205
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'
204206
205207
Returns
@@ -228,6 +230,7 @@ def get_ccs_shift_per_charge(
228230
>>> shifts = get_ccs_shift_per_charge(calibration_df, reference_df)
229231
>>> print(shifts)
230232
{2: 5.2, 3: 3.8, 4: 2.1}
233+
231234
"""
232235
# Validate inputs
233236
_validate_calibration_inputs(
@@ -277,9 +280,7 @@ def get_ccs_shift_per_charge(
277280
# Check for unreasonably large shifts
278281
large_shifts = {k: v for k, v in shift_dict.items() if abs(v) > 100}
279282
if large_shifts:
280-
LOGGER.warning(
281-
f"Large CCS shifts detected: {large_shifts}. " "Please verify data quality."
282-
)
283+
LOGGER.warning(f"Large CCS shifts detected: {large_shifts}. Please verify data quality.")
283284

284285
return shift_dict
285286

@@ -288,8 +289,8 @@ def calculate_ccs_shift(
288289
cal_df: pd.DataFrame,
289290
reference_dataset: pd.DataFrame,
290291
per_charge: bool = True,
291-
use_charge_state: Optional[int] = None,
292-
) -> Union[float, Dict[int, float]]:
292+
use_charge_state: int | None = None,
293+
) -> float | dict[int, float]:
293294
"""
294295
Calculate CCS shift factors with validation and filtering.
295296
@@ -299,14 +300,14 @@ def calculate_ccs_shift(
299300
300301
Parameters
301302
----------
302-
cal_df : pd.DataFrame
303+
cal_df
303304
PSMs with CCS values. Must contain columns: 'sequence', 'charge', 'ccs_observed'
304-
reference_dataset : pd.DataFrame
305+
reference_dataset
305306
Reference dataset with CCS values. Must contain columns: 'peptidoform', 'charge', 'CCS'
306-
per_charge : bool, default True
307+
per_charge
307308
Whether to calculate shift factors per charge state. If False, calculates
308309
a single global shift factor using the specified charge state.
309-
use_charge_state : int, optional
310+
use_charge_state
310311
Charge state to use for global shift calculation when per_charge=False.
311312
Should be in range [2,4]. Default is 2 if not specified.
312313
@@ -334,6 +335,7 @@ def calculate_ccs_shift(
334335
>>>
335336
>>> # Global calibration using charge 2
336337
>>> shift = calculate_ccs_shift(cal_df, ref_df, per_charge=False, use_charge_state=2)
338+
337339
"""
338340
# Validate inputs
339341
_validate_calibration_inputs(cal_df, reference_dataset)
@@ -378,7 +380,7 @@ def linear_calibration(
378380
calibration_dataset: pd.DataFrame,
379381
reference_dataset: pd.DataFrame,
380382
per_charge: bool = True,
381-
use_charge_state: Optional[int] = None,
383+
use_charge_state: int | None = None,
382384
) -> pd.DataFrame:
383385
"""
384386
Calibrate CCS predictions using linear calibration.
@@ -389,20 +391,20 @@ def linear_calibration(
389391
390392
Parameters
391393
----------
392-
preds_df : pd.DataFrame
394+
preds_df
393395
PSMs with CCS predictions. Must contain 'predicted_ccs' column.
394396
Will be modified to include 'charge' and 'shift' columns.
395-
calibration_dataset : pd.DataFrame
397+
calibration_dataset
396398
Calibration dataset with observed CCS values. Must contain columns:
397399
'peptidoform', 'ccs_observed'
398-
reference_dataset : pd.DataFrame
400+
reference_dataset
399401
Reference dataset with CCS values. Must contain columns:
400402
'peptidoform', 'CCS'
401-
per_charge : bool, default True
403+
per_charge
402404
Whether to calculate and apply shift factors per charge state.
403405
If True, uses charge-specific calibration with fallback to global shift.
404406
If False, applies single global shift factor.
405-
use_charge_state : int, optional
407+
use_charge_state
406408
Charge state to use for global shift calculation when per_charge=False.
407409
Default is 2 if not specified.
408410
@@ -446,6 +448,7 @@ def linear_calibration(
446448
... use_charge_state=2
447449
... )
448450
"""
451+
449452
LOGGER.info("Calibrating CCS values using linear calibration...")
450453

451454
# Validate input dataframes
@@ -478,7 +481,7 @@ def linear_calibration(
478481
)
479482

480483
except (AttributeError, ValueError, IndexError) as e:
481-
raise CalibrationError(f"Error parsing peptidoform data: {e}")
484+
raise CalibrationError(f"Error parsing peptidoform data: {e}") from e
482485

483486
if per_charge:
484487
LOGGER.info("Calculating general shift factor for fallback...")
@@ -489,6 +492,8 @@ def linear_calibration(
489492
per_charge=False,
490493
use_charge_state=use_charge_state or 2,
491494
)
495+
# per_charge=False returns float
496+
general_shift = cast(float, general_shift)
492497
except CalibrationError as e:
493498
LOGGER.warning(
494499
f"Could not calculate general shift factor: {e}. Using 0.0 as fallback."
@@ -499,6 +504,8 @@ def linear_calibration(
499504
shift_factor_dict = calculate_ccs_shift(
500505
calibration_dataset, reference_dataset, per_charge=True
501506
)
507+
# per_charge=True returns dict[int, float]
508+
shift_factor_dict = cast(dict[int, float], shift_factor_dict)
502509

503510
# Add charge information to predictions if not present
504511
if "charge" not in preds_df.columns:
@@ -525,6 +532,8 @@ def linear_calibration(
525532
per_charge=False,
526533
use_charge_state=use_charge_state or 2,
527534
)
535+
# per_charge=False returns float
536+
shift_factor = cast(float, shift_factor)
528537
preds_df["predicted_ccs"] += shift_factor
529538
preds_df["shift"] = shift_factor
530539
LOGGER.info(f"Applied global shift factor: {shift_factor:.3f}")

0 commit comments

Comments
 (0)