Skip to content

Commit 0a2577f

Browse files
+ Improved type safety with ParamSpec and TypeVar
1 parent 11e6074 commit 0a2577f

2 files changed

Lines changed: 33 additions & 21 deletions

File tree

src/thread/_types.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from typing import Any, Literal, Callable, Union
8+
from typing_extensions import ParamSpec, TypeVar
89

910

1011
# Descriptive Types
@@ -27,5 +28,10 @@
2728

2829

2930
# Function types
30-
HookFunction = Callable[[Data_Out], Union[Any, None]]
31-
TargetFunction = Callable[..., Data_Out]
31+
_Target_P = ParamSpec('_Target_P')
32+
_Target_T = TypeVar('_Target_T')
33+
_Dataset_T = TypeVar('_Dataset_T')
34+
35+
TargetFunction = Callable[_Target_P, _Target_T]
36+
37+
HookFunction = Callable[[_Target_T], Union[Any, None]]

src/thread/thread.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,22 @@ class ParallelProcessing: ...
1919
from .utils.config import Settings
2020
from .utils.algorithm import chunk_split
2121

22-
from ._types import ThreadStatus, Data_In, Data_Out, Overflow_In, TargetFunction, HookFunction
22+
from ._types import (
23+
ThreadStatus, Data_In, Data_Out, Overflow_In,
24+
TargetFunction, _Target_P, _Target_T,
25+
DatasetFunction, _Dataset_T,
26+
HookFunction
27+
)
28+
from typing_extensions import Generic
2329
from typing import (
24-
Any, List,
25-
Callable, Optional,
30+
Any, List, Unpack,
31+
Callable, Optional, Union,
2632
Mapping, Sequence, Tuple
2733
)
2834

2935

3036
Threads: set['Thread'] = set()
31-
class Thread(threading.Thread):
37+
class Thread(threading.Thread, Generic[_Target_P, _Target_T]):
3238
"""
3339
Wraps python's `threading.Thread` class
3440
---------------------------------------
@@ -51,7 +57,7 @@ class Thread(threading.Thread):
5157

5258
def __init__(
5359
self,
54-
target: TargetFunction,
60+
target: TargetFunction[_Target_P, _Target_T],
5561
args: Sequence[Data_In] = (),
5662
kwargs: Mapping[str, Data_In] = {},
5763
ignore_errors: Sequence[type[Exception]] = (),
@@ -100,10 +106,10 @@ def __init__(
100106
)
101107

102108

103-
def _wrap_target(self, target: TargetFunction) -> TargetFunction:
109+
def _wrap_target(self, target: TargetFunction[_Target_P, _Target_T]) -> TargetFunction[_Target_P, Union[_Target_T, None]]:
104110
"""Wraps the target function"""
105111
@wraps(target)
106-
def wrapper(*args: Any, **kwargs: Any) -> Any:
112+
def wrapper(*args: _Target_P.args, **kwargs: _Target_P.kwargs) -> Union[_Target_T, None]:
107113
self.status = 'Running'
108114

109115
global Threads
@@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:
173179

174180

175181
@property
176-
def result(self) -> Data_Out:
182+
def result(self) -> _Target_T:
177183
"""
178184
The return value of the thread
179185
@@ -208,7 +214,7 @@ def is_alive(self) -> bool:
208214
return super().is_alive()
209215

210216

211-
def add_hook(self, hook: HookFunction) -> None:
217+
def add_hook(self, hook: HookFunction[_Target_T]) -> None:
212218
"""
213219
Adds a hook to the thread
214220
-------------------------
@@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
250256
return not self.is_alive()
251257

252258

253-
def get_return_value(self) -> Data_Out:
259+
def get_return_value(self) -> _Target_T:
254260
"""
255261
Halts the current thread execution until the thread completes
256262
@@ -323,7 +329,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
323329
self.thread = thread
324330
self.progress = progress
325331

326-
class ParallelProcessing:
332+
class ParallelProcessing(Generic[_Target_P, _Target_T, _Dataset_T]):
327333
"""
328334
Multi-Threaded Parallel Processing
329335
---------------------------------------
@@ -335,7 +341,7 @@ class ParallelProcessing:
335341
_completed : int
336342

337343
status : ThreadStatus
338-
function : Callable[..., List[Data_Out]]
344+
function : TargetFunction[..., List[_Target_T]]
339345
dataset : Sequence[Data_In]
340346
max_threads : int
341347

@@ -344,8 +350,8 @@ class ParallelProcessing:
344350

345351
def __init__(
346352
self,
347-
function: TargetFunction,
348-
dataset: Sequence[Data_In],
353+
function: DatasetFunction[_Dataset_T, _Target_T],
354+
dataset: Sequence[_Dataset_T],
349355
max_threads: int = 8,
350356

351357
*overflow_args: Overflow_In,
@@ -385,10 +391,10 @@ def __init__(
385391

386392
def _wrap_function(
387393
self,
388-
function: TargetFunction
389-
) -> Callable[..., List[Data_Out]]:
394+
function: TargetFunction[[_Dataset_T], _Target_T]
395+
) -> TargetFunction[..., List[_Target_T]]:
390396
@wraps(function)
391-
def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any) -> List[Data_Out]:
397+
def wrapper(index: int, data_chunk: Sequence[_Dataset_T], *args: _Target_P.args, **kwargs: _Target_P.kwargs) -> List[_Target_T]:
392398
computed: List[Data_Out] = []
393399
for i, data_entry in enumerate(data_chunk):
394400
v = function(data_entry, *args, **kwargs)
@@ -404,7 +410,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any
404410

405411

406412
@property
407-
def results(self) -> Data_Out:
413+
def results(self) -> List[_Dataset_T]:
408414
"""
409415
The return value of the threads if completed
410416
@@ -436,7 +442,7 @@ def is_alive(self) -> bool:
436442
return any(entry.thread.is_alive() for entry in self._threads)
437443

438444

439-
def get_return_values(self) -> List[Data_Out]:
445+
def get_return_values(self) -> List[_Dataset_T]:
440446
"""
441447
Halts the current thread execution until the thread completes
442448

0 commit comments

Comments
 (0)