Skip to content

Commit ba1dad3

Browse files
feat: Impose requirements
Ref: #67
1 parent 5d4da34 commit ba1dad3

1 file changed

Lines changed: 22 additions & 11 deletions

File tree

src/thread/thread.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -455,24 +455,35 @@ def __init__(
455455
"""
456456
assert 0 <= max_threads, 'max_threads cannot be set to 0'
457457

458-
# Enforce required arguments
459-
if not isinstance(dataset, SupportsLength):
460-
assert (
461-
_length
462-
), '`_length` must be set if `dataset` does not support `__len__`'
458+
# Impose requirements
459+
if isinstance(dataset, SupportsLengthGetItem):
460+
_length = _length(dataset) if callable(_length) else _length
461+
length = len(dataset) if _length is None else _length
463462

464-
if not hasattr(dataset, '__getitem__'):
463+
get_value = _get_value or dataset.__class__.__getitem__
464+
465+
elif isinstance(dataset, SupportsLength):
465466
assert (
466467
_get_value
467468
), '`_get_value` must be set if `dataset` does not support `__getitem__`'
469+
_length = _length(dataset) if callable(_length) else _length
470+
length = len(dataset) if _length is None else _length
471+
472+
get_value = _get_value
473+
474+
elif isinstance(dataset, SupportsGetItem):
475+
assert (
476+
_length
477+
), '`_length` must be set if `dataset` does not support `__len__`'
478+
length = _length(dataset) if callable(_length) else _length
468479

469-
_length = _length(dataset) if callable(_length) else _length
470-
_length = len(dataset) if isinstance(dataset, SupportsLength) else _length
480+
get_value = _get_value or dataset.__class__.__getitem__
471481

472-
assert isinstance(_length, int), '`_length` must be an integer'
473-
assert _length > 0, 'dataset cannot be empty'
482+
assert isinstance(length, int), '`_length` must be an integer'
483+
assert length > 0, 'dataset cannot be empty'
484+
assert get_value, '`_get_value` must be set'
474485

475-
self._length = _length
486+
self._length = length
476487
self._threads = []
477488
self._completed = 0
478489

0 commit comments

Comments
 (0)