@@ -33,9 +33,23 @@ class ParallelProcessing: ...
3333 DatasetFunction ,
3434 _Dataset_T ,
3535 HookFunction ,
36+ SupportsLength ,
37+ SupportsGetItem ,
38+ SupportsLengthGetItem ,
3639)
3740from typing_extensions import Generic
38- from typing import List , Optional , Union , Mapping , Sequence , Tuple , Generator
41+ from typing import (
42+ Any ,
43+ List ,
44+ Optional ,
45+ Union ,
46+ Mapping ,
47+ Sequence ,
48+ Tuple ,
49+ Callable ,
50+ Generator ,
51+ overload ,
52+ )
3953
4054
4155Threads : set ['Thread' ] = set ()
@@ -337,18 +351,68 @@ class ParallelProcessing(Generic[_Target_P, _Target_T, _Dataset_T]):
337351
338352 status : ThreadStatus
339353 function : TargetFunction
340- dataset : Sequence [Data_In ]
354+ dataset : Union [
355+ Sequence [_Dataset_T ],
356+ SupportsLength ,
357+ SupportsGetItem [_Dataset_T ],
358+ SupportsLengthGetItem [_Dataset_T ],
359+ ]
341360 max_threads : int
342361
343362 overflow_args : Sequence [Overflow_In ]
344363 overflow_kwargs : Mapping [str , Overflow_In ]
345364
365+ @overload
366+ def __init__ (
367+ self ,
368+ function : DatasetFunction [_Dataset_T , _Target_P , _Target_T ],
369+ dataset : Union [Sequence [_Dataset_T ], SupportsLengthGetItem [_Dataset_T ]],
370+ max_threads : int = 8 ,
371+ * overflow_args : Overflow_In ,
372+ _get_value : Optional [Callable [[Sequence [_Dataset_T ], int ], _Dataset_T ]] = None ,
373+ _length : Optional [Union [int , Callable [[Sequence [_Dataset_T ]], int ]]] = None ,
374+ ** overflow_kwargs : Overflow_In ,
375+ ) -> None : ...
376+
377+ # Has __len__, require _get_value to be set
378+ @overload
379+ def __init__ (
380+ self ,
381+ function : DatasetFunction [_Dataset_T , _Target_P , _Target_T ],
382+ dataset : SupportsLength ,
383+ max_threads : int = 8 ,
384+ * overflow_args : Overflow_In ,
385+ _get_value : Callable [[Sequence [_Dataset_T ], int ], _Dataset_T ],
386+ _length : Optional [Union [int , Callable [[Sequence [_Dataset_T ]], int ]]] = None ,
387+ ** overflow_kwargs : Overflow_In ,
388+ ) -> None : ...
389+
390+ # Has __getitem__, require _length to be set
391+ @overload
392+ def __init__ (
393+ self ,
394+ function : DatasetFunction [_Dataset_T , _Target_P , _Target_T ],
395+ dataset : SupportsLength ,
396+ max_threads : int = 8 ,
397+ * overflow_args : Overflow_In ,
398+ _get_value : Optional [Callable [[Sequence [_Dataset_T ], int ], _Dataset_T ]] = None ,
399+ _length : Union [int , Callable [[Sequence [_Dataset_T ]], int ]],
400+ ** overflow_kwargs : Overflow_In ,
401+ ) -> None : ...
402+
346403 def __init__ (
347404 self ,
348405 function : DatasetFunction [_Dataset_T , _Target_P , _Target_T ],
349- dataset : Sequence [_Dataset_T ],
406+ dataset : Union [
407+ Sequence [_Dataset_T ],
408+ SupportsLength ,
409+ SupportsGetItem [_Dataset_T ],
410+ SupportsLengthGetItem [_Dataset_T ],
411+ ],
350412 max_threads : int = 8 ,
351413 * overflow_args : Overflow_In ,
414+ _get_value : Optional [Callable [[Sequence [_Dataset_T ], int ], _Dataset_T ]] = None ,
415+ _length : Optional [Union [int , Callable [[Sequence [_Dataset_T ]], int ]]] = None ,
352416 ** overflow_kwargs : Overflow_In ,
353417 ) -> None :
354418 """
@@ -363,6 +427,8 @@ def __init__(
363427 :param dataset: This should be an iterable sequence of data entries
364428 :param max_threads: This should be an integer value of the max threads allowed
365429 :param *: These are arguments parsed to `threading.Thread` and `Thread`
430+ :param _get_value: This should be a function that takes in the dataset and the index and returns the data entry
431+ :param _length: This should be an integer or a function that takes in the dataset and returns the length
366432 :param **: These are arguments parsed to `thread.Thread` and `Thread`
367433
368434 Raises
0 commit comments