@@ -455,24 +455,35 @@ def __init__(
455455 """
456456 assert 0 <= max_threads , 'max_threads cannot be set to 0'
457457
458- # Enforce required arguments
459- if not isinstance (dataset , SupportsLength ):
460- assert (
461- _length
462- ), '`_length` must be set if `dataset` does not support `__len__`'
458+ # Impose requirements
459+ if isinstance (dataset , SupportsLengthGetItem ):
460+ _length = _length (dataset ) if callable (_length ) else _length
461+ length = len (dataset ) if _length is None else _length
463462
464- if not hasattr (dataset , '__getitem__' ):
463+ get_value = _get_value or dataset .__class__ .__getitem__
464+
465+ elif isinstance (dataset , SupportsLength ):
465466 assert (
466467 _get_value
467468 ), '`_get_value` must be set if `dataset` does not support `__getitem__`'
469+ _length = _length (dataset ) if callable (_length ) else _length
470+ length = len (dataset ) if _length is None else _length
471+
472+ get_value = _get_value
473+
474+ elif isinstance (dataset , SupportsGetItem ):
475+ assert (
476+ _length
477+ ), '`_length` must be set if `dataset` does not support `__len__`'
478+ length = _length (dataset ) if callable (_length ) else _length
468479
469- _length = _length (dataset ) if callable (_length ) else _length
470- _length = len (dataset ) if isinstance (dataset , SupportsLength ) else _length
480+ get_value = _get_value or dataset .__class__ .__getitem__
471481
472- assert isinstance (_length , int ), '`_length` must be an integer'
473- assert _length > 0 , 'dataset cannot be empty'
482+ assert isinstance (length , int ), '`_length` must be an integer'
483+ assert length > 0 , 'dataset cannot be empty'
484+ assert get_value , '`_get_value` must be set'
474485
475- self ._length = _length
486+ self ._length = length
476487 self ._threads = []
477488 self ._completed = 0
478489
0 commit comments