Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/functional/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ def get_uses_mapping():
infer_timeouts = {
TargetDevice.CPU: default_infer_timeout,
TargetDevice.GPU: default_gpu_infer_timeout,
TargetDevice.GPU_0: default_gpu_infer_timeout,
TargetDevice.GPU_1: default_gpu_infer_timeout,
TargetDevice.GPU_2: default_gpu_infer_timeout,
TargetDevice.NPU: default_npu_infer_timeout,
TargetDevice.AUTO: default_gpu_infer_timeout,
TargetDevice.HETERO: default_gpu_infer_timeout,
Expand Down
3 changes: 3 additions & 0 deletions tests/functional/constants/target_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
class TargetDevice:
CPU = "CPU"
GPU = "GPU"
GPU_0 = "GPU:0"
GPU_1 = "GPU:1"
GPU_2 = "GPU:2"
NPU = "NPU"
AUTO = "AUTO:GPU,CPU"
HETERO = "HETERO:GPU,CPU"
Expand Down
27 changes: 25 additions & 2 deletions tests/functional/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,27 @@ def get_multi_target_devices(target_devices_list, separator):
return result


def _is_device_with_index(device_str):
""" Check if device string is a device with numeric index, e.g. GPU:0, GPU:1. """
if ":" in device_str:
_, suffix = device_str.split(":", 1)
return suffix.isdigit()
return False


def validate_supported_values(detected_list, supported_list):
supported_list += ALL_AVAILABLE_OPTIONS # 'starred expression' will be evaluated during pytest_configure
check = all(_elem in supported_list for _elem in detected_list)

def _is_supported(device):
if device in supported_list:
return True
# Accept indexed devices like GPU:0, GPU:1 if base device (GPU) is supported
if _is_device_with_index(device):
base_device = device.split(":", 1)[0]
return base_device in supported_list
return False

check = all(_is_supported(_elem) for _elem in detected_list)
assert check, f"Not supported target devices in {detected_list}"
return detected_list

Expand All @@ -106,7 +124,12 @@ def get_target_devices():
""" Convert comma separated string of devices into list """
target_devices_list = get_list("TT_TARGET_DEVICE", fallback=[TargetDevice.CPU])
separator_multi = ":"
if any(separator_multi in _target_device for _target_device in target_devices_list):
# Only treat as multi-target if ':' is followed by a device name, not a numeric index (GPU:0, GPU:1)
has_multi_target = any(
separator_multi in _td and not _is_device_with_index(_td)
for _td in target_devices_list
)
if has_multi_target:
target_devices_list = get_multi_target_devices(target_devices_list, separator_multi)
ov_target_devices = [value for key, value in vars(TargetDevice).items() if not key.startswith("__")]
target_devices_list = validate_supported_values(detected_list=target_devices_list, supported_list=ov_target_devices)
Expand Down