Skip to content

Commit 61cd504

Browse files
committed
Allocate device_ptr memory only for sizes > 0 on all backends (fixing test errors on CPUs).
1 parent 1caf833 commit 61cd504

4 files changed

Lines changed: 6 additions & 8 deletions

File tree

src/plssvm/backends/CUDA/detail/device_ptr.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ device_ptr<T>::device_ptr(const plssvm::shape shape, const plssvm::shape padding
4040
if (queue_ < 0 || queue_ >= get_device_count()) {
4141
throw backend_exception{ fmt::format("Illegal device ID! Must be in range: [0, {}) but is {}.", get_device_count(), queue_) };
4242
}
43-
detail::set_device(queue_);
44-
PLSSVM_CUDA_ERROR_CHECK(cudaMalloc(&data_, this->size_padded() * sizeof(value_type)))
4543

4644
// only non-empty pointers must be memset in the constructor
4745
if (this->size_padded() != std::size_t{ 0 }) {
46+
detail::set_device(queue_);
47+
PLSSVM_CUDA_ERROR_CHECK(cudaMalloc(&data_, this->size_padded() * sizeof(value_type)))
4848
this->memset(0);
4949
}
5050
}

src/plssvm/backends/HIP/detail/device_ptr.hip

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,11 @@ device_ptr<T>::device_ptr(const plssvm::shape shape, const plssvm::shape padding
4242
if (queue_ < 0 || queue_ >= get_device_count()) {
4343
throw backend_exception{ fmt::format("Illegal device ID! Must be in range: [0, {}) but is {}.", get_device_count(), queue_) };
4444
}
45-
detail::set_device(queue_);
46-
PLSSVM_HIP_ERROR_CHECK(hipMalloc(&data_, this->size_padded() * sizeof(value_type)))
4745

4846
// only non-empty pointers must be memset in the constructor
4947
if (this->size_padded() != std::size_t{ 0 }) {
48+
detail::set_device(queue_);
49+
PLSSVM_HIP_ERROR_CHECK(hipMalloc(&data_, this->size_padded() * sizeof(value_type)))
5050
this->memset(0);
5151
}
5252
}

src/plssvm/backends/SYCL/AdaptiveCpp/detail/device_ptr.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,9 @@ device_ptr<T>::device_ptr(const plssvm::shape shape, const queue &q) :
3838
template <typename T>
3939
device_ptr<T>::device_ptr(const plssvm::shape shape, const plssvm::shape padding, const queue &q) :
4040
base_type{ shape, padding, q } {
41-
data_ = ::sycl::malloc_device<value_type>(this->size_padded(), queue_.impl->sycl_queue);
42-
4341
// only non-empty pointers must be memset in the constructor
4442
if (this->size_padded() != std::size_t{ 0 }) {
43+
data_ = ::sycl::malloc_device<value_type>(this->size_padded(), queue_.impl->sycl_queue);
4544
this->memset(0);
4645
}
4746
}

src/plssvm/backends/SYCL/DPCPP/detail/device_ptr.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,9 @@ device_ptr<T>::device_ptr(const plssvm::shape shape, const queue &q) :
3939
template <typename T>
4040
device_ptr<T>::device_ptr(const plssvm::shape shape, plssvm::shape padding, const queue &q) :
4141
base_type{ shape, padding, q } {
42-
data_ = ::sycl::malloc_device<value_type>(this->size_padded(), queue_.impl->sycl_queue);
43-
4442
// only non-empty pointers must be memset in the constructor
4543
if (this->size_padded() != std::size_t{ 0 }) {
44+
data_ = ::sycl::malloc_device<value_type>(this->size_padded(), queue_.impl->sycl_queue);
4645
this->memset(0);
4746
}
4847
}

0 commit comments

Comments
 (0)