Skip to content

Commit 7bcba2a

Browse files
feat: Overloading ParallelProcessing init
1 parent 1d453e0 commit 7bcba2a

1 file changed

Lines changed: 69 additions & 3 deletions

File tree

src/thread/thread.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,23 @@ class ParallelProcessing: ...
3333
DatasetFunction,
3434
_Dataset_T,
3535
HookFunction,
36+
SupportsLength,
37+
SupportsGetItem,
38+
SupportsLengthGetItem,
3639
)
3740
from typing_extensions import Generic
38-
from typing import List, Optional, Union, Mapping, Sequence, Tuple, Generator
41+
from typing import (
42+
Any,
43+
List,
44+
Optional,
45+
Union,
46+
Mapping,
47+
Sequence,
48+
Tuple,
49+
Callable,
50+
Generator,
51+
overload,
52+
)
3953

4054

4155
Threads: set['Thread'] = set()
@@ -337,18 +351,68 @@ class ParallelProcessing(Generic[_Target_P, _Target_T, _Dataset_T]):
337351

338352
status: ThreadStatus
339353
function: TargetFunction
340-
dataset: Sequence[Data_In]
354+
dataset: Union[
355+
Sequence[_Dataset_T],
356+
SupportsLength,
357+
SupportsGetItem[_Dataset_T],
358+
SupportsLengthGetItem[_Dataset_T],
359+
]
341360
max_threads: int
342361

343362
overflow_args: Sequence[Overflow_In]
344363
overflow_kwargs: Mapping[str, Overflow_In]
345364

365+
@overload
366+
def __init__(
367+
self,
368+
function: DatasetFunction[_Dataset_T, _Target_P, _Target_T],
369+
dataset: Union[Sequence[_Dataset_T], SupportsLengthGetItem[_Dataset_T]],
370+
max_threads: int = 8,
371+
*overflow_args: Overflow_In,
372+
_get_value: Optional[Callable[[Sequence[_Dataset_T], int], _Dataset_T]] = None,
373+
_length: Optional[Union[int, Callable[[Sequence[_Dataset_T]], int]]] = None,
374+
**overflow_kwargs: Overflow_In,
375+
) -> None: ...
376+
377+
# Has __len__, require _get_value to be set
378+
@overload
379+
def __init__(
380+
self,
381+
function: DatasetFunction[_Dataset_T, _Target_P, _Target_T],
382+
dataset: SupportsLength,
383+
max_threads: int = 8,
384+
*overflow_args: Overflow_In,
385+
_get_value: Callable[[Sequence[_Dataset_T], int], _Dataset_T],
386+
_length: Optional[Union[int, Callable[[Sequence[_Dataset_T]], int]]] = None,
387+
**overflow_kwargs: Overflow_In,
388+
) -> None: ...
389+
390+
# Has __getitem__, require _length to be set
391+
@overload
392+
def __init__(
393+
self,
394+
function: DatasetFunction[_Dataset_T, _Target_P, _Target_T],
395+
dataset: SupportsLength,
396+
max_threads: int = 8,
397+
*overflow_args: Overflow_In,
398+
_get_value: Optional[Callable[[Sequence[_Dataset_T], int], _Dataset_T]] = None,
399+
_length: Union[int, Callable[[Sequence[_Dataset_T]], int]],
400+
**overflow_kwargs: Overflow_In,
401+
) -> None: ...
402+
346403
def __init__(
347404
self,
348405
function: DatasetFunction[_Dataset_T, _Target_P, _Target_T],
349-
dataset: Sequence[_Dataset_T],
406+
dataset: Union[
407+
Sequence[_Dataset_T],
408+
SupportsLength,
409+
SupportsGetItem[_Dataset_T],
410+
SupportsLengthGetItem[_Dataset_T],
411+
],
350412
max_threads: int = 8,
351413
*overflow_args: Overflow_In,
414+
_get_value: Optional[Callable[[Sequence[_Dataset_T], int], _Dataset_T]] = None,
415+
_length: Optional[Union[int, Callable[[Sequence[_Dataset_T]], int]]] = None,
352416
**overflow_kwargs: Overflow_In,
353417
) -> None:
354418
"""
@@ -363,6 +427,8 @@ def __init__(
363427
:param dataset: This should be an iterable sequence of data entries
364428
:param max_threads: This should be an integer value of the max threads allowed
365429
:param *: These are arguments parsed to `threading.Thread` and `Thread`
430+
:param _get_value: This should be a function that takes in the dataset and the index and returns the data entry
431+
:param _length: This should be an integer or a function that takes in the dataset and returns the length
366432
:param **: These are arguments parsed to `thread.Thread` and `Thread`
367433
368434
Raises

0 commit comments

Comments
 (0)