@@ -19,16 +19,22 @@ class ParallelProcessing: ...
1919from .utils .config import Settings
2020from .utils .algorithm import chunk_split
2121
22- from ._types import ThreadStatus , Data_In , Data_Out , Overflow_In , TargetFunction , HookFunction
22+ from ._types import (
23+ ThreadStatus , Data_In , Data_Out , Overflow_In ,
24+ TargetFunction , _Target_P , _Target_T ,
25+ DatasetFunction , _Dataset_T ,
26+ HookFunction
27+ )
28+ from typing_extensions import Generic
2329from typing import (
24- Any , List ,
25- Callable , Optional ,
30+ Any , List , Unpack ,
31+ Callable , Optional , Union ,
2632 Mapping , Sequence , Tuple
2733)
2834
2935
3036Threads : set ['Thread' ] = set ()
31- class Thread (threading .Thread ):
37+ class Thread (threading .Thread , Generic [ _Target_P , _Target_T ] ):
3238 """
3339 Wraps python's `threading.Thread` class
3440 ---------------------------------------
@@ -51,7 +57,7 @@ class Thread(threading.Thread):
5157
5258 def __init__ (
5359 self ,
54- target : TargetFunction ,
60+ target : TargetFunction [ _Target_P , _Target_T ] ,
5561 args : Sequence [Data_In ] = (),
5662 kwargs : Mapping [str , Data_In ] = {},
5763 ignore_errors : Sequence [type [Exception ]] = (),
@@ -100,10 +106,10 @@ def __init__(
100106 )
101107
102108
103- def _wrap_target (self , target : TargetFunction ) -> TargetFunction :
109+ def _wrap_target (self , target : TargetFunction [ _Target_P , _Target_T ] ) -> TargetFunction [ _Target_P , Union [ _Target_T , None ]] :
104110 """Wraps the target function"""
105111 @wraps (target )
106- def wrapper (* args : Any , ** kwargs : Any ) -> Any :
112+ def wrapper (* args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> Union [ _Target_T , None ] :
107113 self .status = 'Running'
108114
109115 global Threads
@@ -173,7 +179,7 @@ def _run_with_trace(self) -> None:
173179
174180
175181 @property
176- def result (self ) -> Data_Out :
182+ def result (self ) -> _Target_T :
177183 """
178184 The return value of the thread
179185
@@ -208,7 +214,7 @@ def is_alive(self) -> bool:
208214 return super ().is_alive ()
209215
210216
211- def add_hook (self , hook : HookFunction ) -> None :
217+ def add_hook (self , hook : HookFunction [ _Target_T ] ) -> None :
212218 """
213219 Adds a hook to the thread
214220 -------------------------
@@ -250,7 +256,7 @@ def join(self, timeout: Optional[float] = None) -> bool:
250256 return not self .is_alive ()
251257
252258
253- def get_return_value (self ) -> Data_Out :
259+ def get_return_value (self ) -> _Target_T :
254260 """
255261 Halts the current thread execution until the thread completes
256262
@@ -323,7 +329,7 @@ def __init__(self, thread: Thread, progress: float = 0) -> None:
323329 self .thread = thread
324330 self .progress = progress
325331
326- class ParallelProcessing :
332+ class ParallelProcessing ( Generic [ _Target_P , _Target_T , _Dataset_T ]) :
327333 """
328334 Multi-Threaded Parallel Processing
329335 ---------------------------------------
@@ -335,7 +341,7 @@ class ParallelProcessing:
335341 _completed : int
336342
337343 status : ThreadStatus
338- function : Callable [..., List [Data_Out ]]
344+ function : TargetFunction [..., List [_Target_T ]]
339345 dataset : Sequence [Data_In ]
340346 max_threads : int
341347
@@ -344,8 +350,8 @@ class ParallelProcessing:
344350
345351 def __init__ (
346352 self ,
347- function : TargetFunction ,
348- dataset : Sequence [Data_In ],
353+ function : DatasetFunction [ _Dataset_T , _Target_T ] ,
354+ dataset : Sequence [_Dataset_T ],
349355 max_threads : int = 8 ,
350356
351357 * overflow_args : Overflow_In ,
@@ -385,10 +391,10 @@ def __init__(
385391
386392 def _wrap_function (
387393 self ,
388- function : TargetFunction
389- ) -> Callable [..., List [Data_Out ]]:
394+ function : TargetFunction [[ _Dataset_T ], _Target_T ]
395+ ) -> TargetFunction [..., List [_Target_T ]]:
390396 @wraps (function )
391- def wrapper (index : int , data_chunk : Sequence [Data_In ], * args : Any , ** kwargs : Any ) -> List [Data_Out ]:
397+ def wrapper (index : int , data_chunk : Sequence [_Dataset_T ], * args : _Target_P . args , ** kwargs : _Target_P . kwargs ) -> List [_Target_T ]:
392398 computed : List [Data_Out ] = []
393399 for i , data_entry in enumerate (data_chunk ):
394400 v = function (data_entry , * args , ** kwargs )
@@ -404,7 +410,7 @@ def wrapper(index: int, data_chunk: Sequence[Data_In], *args: Any, **kwargs: Any
404410
405411
406412 @property
407- def results (self ) -> Data_Out :
413+ def results (self ) -> List [ _Dataset_T ] :
408414 """
409415 The return value of the threads if completed
410416
@@ -436,7 +442,7 @@ def is_alive(self) -> bool:
436442 return any (entry .thread .is_alive () for entry in self ._threads )
437443
438444
439- def get_return_values (self ) -> List [Data_Out ]:
445+ def get_return_values (self ) -> List [_Dataset_T ]:
440446 """
441447 Halts the current thread execution until the thread completes
442448
0 commit comments