Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 159 additions & 1 deletion src/native/cambricon/common.h
Comment thread
bitzyz marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace infini::ops::reduce {

// Number of `float` elements that fit in 128 bytes. The reduction helpers in
// this namespace use it as their per-step element count (pooling channel
// count, and the granularity `SumBatched` aligns batches to).
constexpr int batch_size = 128 / sizeof(float);

__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
__mlu_func__ void SumInternal(float* src, float* dst, int max_batch) {
const int width = max_batch / batch_size;

if (width >= 4) {
Expand All @@ -30,6 +30,164 @@ __mlu_func__ void SumInternal(float* dst, float* src, int max_batch) {
}
}

// Reduces `len` elements held in `data` with `SumInternal`, which leaves its
// output in `result`.
//
// For `__half` and `__bang_bfloat16` the packed values are read starting at
// `data + len` and widened to `float` starting at the beginning of `data`, so
// the buffer is overwritten and must be large enough to hold `len` floats.
// Any other `T` is handed to `SumInternal` unchanged.
template <typename T>
__mlu_func__ void SumTyped(T* data, float* result, size_t len) {
  if constexpr (std::is_same_v<T, __half>) {
    float* widened = reinterpret_cast<float*>(data);
    half* packed = reinterpret_cast<half*>(data) + len;
    __bang_half2float(widened, packed, len);
    SumInternal(widened, result, len);
  } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
    float* widened = reinterpret_cast<float*>(data);
    T* packed = data + len;
    __bang_bfloat162float(widened, packed, len);
    SumInternal(widened, result, len);
  } else {
    SumInternal(data, result, len);
  }
}

// Sums `num_elements` values from `source`, processing at most `max_batch`
// values per step.
//
// Each step copies a chunk into the staging buffer `src` (GDRAM2NRAM copy),
// runs `SumTyped` over all `max_batch` slots and adds `dst[0]` to the running
// total. A chunk shorter than `max_batch` is preceded by a zero fill of the
// staging buffer so that the unused slots contribute nothing.
//
// Two-byte element types are staged `max_batch` elements into `src`, which is
// where `SumTyped` expects to find the packed input.
template <typename T>
__mlu_func__ float Sum(const T* source, T* src, float* dst, int num_elements,
                       int max_batch) {
  const int offset = (sizeof(T) == 2 ? max_batch : 0);
  float total = 0.0f;

  for (size_t done = 0; done < num_elements;) {
    const size_t chunk = std::min<size_t>(max_batch, num_elements - done);
    const bool is_partial = chunk < max_batch;

    if (is_partial) {
      __bang_write_value(src, max_batch + offset, 0);
    }

    __memcpy(src + offset, source + done, chunk * sizeof(T), GDRAM2NRAM);
    SumTyped(src, dst, max_batch);

    total += dst[0];
    done += chunk;
  }

  return total;
}

// Sums `num_elements` values from `source`, processing at most `max_batch`
// values per step. Inputs shorter than `min_vector_size` are delegated to
// `Sum`.
//
// Per step: the staging buffer `src` is zeroed, a chunk is copied in
// (GDRAM2NRAM copy), two-byte types are widened to `float`, the part of the
// chunk that is a multiple of `batch_size` is reduced with `SumInternal`, and
// the remaining tail is accumulated element by element.
//
// Buffer layout: two-byte inputs are staged at `src + max_batch` and widened
// into the start of `src`, the same layout `SumTyped` uses. Widening into
// `src + offset` instead would make the destination start at the very address
// the packed input is read from, and would write `curr_batch` floats starting
// half-way into the `max_batch + offset` elements this function zeroes.
// NOTE(review): this assumes `SumInternal` reads only the element count it is
// given - confirm against its definition.
template <typename T>
__mlu_func__ float SumBatched(const T* source, T* src, float* dst,
                              int num_elements, int max_batch) {
  constexpr int min_vector_size = 32;

  if (num_elements < min_vector_size) {
    return Sum(source, src, dst, num_elements, max_batch);
  }

  float res = 0.0f;
  int offset = (sizeof(T) == 2 ? max_batch : 0);

  // Float view of the staging buffer. For `float` inputs `offset` is 0, so
  // this is exactly where the chunk was copied to.
  float* values = reinterpret_cast<float*>(src);

  size_t processed = 0;
  while (processed < num_elements) {
    size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);
    size_t aligned_batch = (curr_batch / batch_size) * batch_size;

    // Ensure NRAM buffer is zeroed.
    __bang_write_value(src, max_batch + offset, 0);

    // Copy data to NRAM.
    __memcpy(src + offset, source + processed, curr_batch * sizeof(T),
             GDRAM2NRAM);

    // Widen two-byte inputs from the staged position down to the buffer start.
    if constexpr (std::is_same_v<T, __half>) {
      __bang_half2float(values, reinterpret_cast<half*>(src) + offset,
                        curr_batch);
    } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
      __bang_bfloat162float(values, src + offset, curr_batch);
    }

    if (aligned_batch > 0) {
      SumInternal(values, dst, aligned_batch);
      res += dst[0];
    }

    // Tail that does not fill a whole `batch_size` group.
    for (size_t i = aligned_batch; i < curr_batch; ++i) {
      res += values[i];
    }

    processed += curr_batch;
  }

  return res;
}

// Max-reduces `max_batch` floats from `src` into `dst`: a max-pool with
// `batch_size` channels over `max_batch / batch_size` columns, followed by an
// in-place `__bang_argmax` over those `batch_size` partial results. Callers in
// this file read the maximum from `dst[0]`.
__mlu_func__ void MaxInternal(float* src, float* dst, int max_batch) {
  const int width = max_batch / batch_size;

  __bang_maxpool(dst, src, batch_size, 1, width, 1, width, 1, 1);
  __bang_argmax(dst, dst, batch_size);
}

// Max-reduces `len` elements held in `data` with `MaxInternal`, which leaves
// its output in `result`.
//
// For `__half` and `__bang_bfloat16` the packed values are read starting at
// `data + len` and widened to `float` starting at the beginning of `data`, so
// the buffer is overwritten and must be large enough to hold `len` floats.
// Any other `T` is handed to `MaxInternal` unchanged.
template <typename T>
__mlu_func__ void MaxTyped(T* data, float* result, size_t len) {
  if constexpr (std::is_same_v<T, __half>) {
    float* widened = reinterpret_cast<float*>(data);
    half* packed = reinterpret_cast<half*>(data) + len;
    __bang_half2float(widened, packed, len);
    MaxInternal(widened, result, len);
  } else if constexpr (std::is_same_v<T, __bang_bfloat16>) {
    float* widened = reinterpret_cast<float*>(data);
    T* packed = data + len;
    __bang_bfloat162float(widened, packed, len);
    MaxInternal(widened, result, len);
  } else {
    MaxInternal(data, result, len);
  }
}

// Returns the maximum of `num_elements` values from `source`, processing at
// most `max_batch` values per step. Returns `-INFINITY` when the loop body
// never runs.
//
// Each step copies a chunk into the staging buffer `src` (GDRAM2NRAM copy),
// runs `MaxTyped` over all `max_batch` slots and folds `dst[0]` into the
// running maximum. Two-byte element types are staged `max_batch` elements
// into `src`, which is where `MaxTyped` expects to find the packed input.
//
// Padding: `MaxTyped` always reduces `max_batch` slots, so the unused slots of
// a short chunk must not be able to win. They are filled with the chunk's own
// first element rather than 0; zero padding makes an all-negative chunk
// report 0 as its maximum. Only the staged region needs filling because
// `MaxTyped` overwrites the rest of the buffer when it widens two-byte input.
// NOTE(review): the fill value is read from `source` with a scalar load -
// confirm that is acceptable for the memory `source` points to.
template <typename T>
__mlu_func__ float Max(const T* source, T* src, float* dst, int num_elements,
                       int max_batch) {
  float max_val = -INFINITY;
  int offset = (sizeof(T) == 2 ? max_batch : 0);

  size_t processed = 0;
  while (processed < num_elements) {
    size_t curr_batch = std::min<size_t>(max_batch, num_elements - processed);

    if (curr_batch < max_batch) {
      // `curr_batch >= 1` here, so `source[processed]` is a valid element of
      // this chunk and can never exceed the chunk's true maximum.
      const T filler = source[processed];
      __bang_write_value(src + offset, max_batch, filler);
    }

    __memcpy(src + offset, source + processed, curr_batch * sizeof(T),
             GDRAM2NRAM);
    MaxTyped(src, dst, max_batch);
    max_val = std::max(max_val, dst[0]);
    processed += curr_batch;
  }

  return max_val;
}

// Returns the maximum of `num_elements` values from `source`, processing at
// most `max_batch` values per step through the staging buffer `src` and the
// scratch output `dst`.
//
// This entry point mirrors `SumBatched` but has no separate vector path: the
// loop it used to carry for inputs of 32 or more elements was a verbatim copy
// of the loop in `Max`, so both the short-input and long-input cases are
// served by `Max` directly.
template <typename T>
__mlu_func__ float MaxBatched(const T* source, T* src, float* dst,
                              int num_elements, int max_batch) {
  return Max(source, src, dst, num_elements, max_batch);
}

} // namespace infini::ops::reduce

#endif // __BANG__
Expand Down
63 changes: 63 additions & 0 deletions src/native/cambricon/ops/causal_softmax/causal_softmax.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#ifndef INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H
#define INFINI_OPS_CAMBRICON_CAUSAL_SOFTMAX_H

#include "base/causal_softmax.h"
#include "native/cambricon/common.h"
#include "native/cambricon/data_type_.h"

namespace infini::ops {

// TODO: Remove forward declaration.
//
// Declaration only; the definition is not part of this header. The Cambricon
// `Operator<CausalSoftmax, ...>::operator()` in this file calls it with `T`
// selected from the input tensor's dtype, passing the launch configuration
// (`core_per_cluster`, `cluster_count`), the queue, the input and output data
// pointers, the problem sizes, and three strides each for `y` and `x`.
// NOTE(review): presumably the `_b`/`_i`/`_j` strides are in elements for the
// batch / row / column axes - confirm against the kernel definition.
template <typename T>
void CausalSoftmaxUnion(void* workspace, int core_per_cluster,
int cluster_count, cnrtQueue_t queue, const void* x,
void* y, size_t batch_size_, size_t seq_len_,
size_t total_seq_len_, ptrdiff_t y_stride_b,
ptrdiff_t y_stride_i, ptrdiff_t y_stride_j,
ptrdiff_t x_stride_b, ptrdiff_t x_stride_i,
ptrdiff_t x_stride_j);

template <>
class Operator<CausalSoftmax, Device::Type::kCambricon> : public CausalSoftmax {
public:
Operator(const Tensor input, Tensor out) : CausalSoftmax{input, out} {
cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster,
&cluster_count);
}

void operator()(const Tensor input, Tensor out) const override {
Comment thread
bitzyz marked this conversation as resolved.
auto queue = static_cast<cnrtQueue_t>(stream_ ? stream_ : 0);
auto workspace{workspace_ ? workspace_ : default_workspace_};
ptrdiff_t y_stride_b = ndim_ == 3 ? out_strides_[0] : 1;
ptrdiff_t y_stride_i = ndim_ == 3 ? out_strides_[1] : out_strides_[0];
ptrdiff_t y_stride_j = ndim_ == 3 ? out_strides_[2] : out_strides_[1];
ptrdiff_t x_stride_b = ndim_ == 3 ? input_strides_[0] : 1;
ptrdiff_t x_stride_i = ndim_ == 3 ? input_strides_[1] : input_strides_[0];
ptrdiff_t x_stride_j = ndim_ == 3 ? input_strides_[2] : input_strides_[1];

DispatchFunc<
List<DataType::kFloat16, DataType::kBFloat16, DataType::kFloat32>>(
{static_cast<int64_t>(input.dtype())},
[&](auto input_tag) {
using InputT = infini::ops::TypeMapType<Device::Type::kCambricon,
ListGet<0>(input_tag)>;
CausalSoftmaxUnion<InputT>(
workspace, core_per_cluster, cluster_count, queue, input.data(),
out.data(), batch_size_, seq_len_, total_seq_len_, y_stride_b,
y_stride_i, y_stride_j, x_stride_b, x_stride_i, x_stride_j);
},
"CambriconCausalSoftmax::operator() - output dispatch");
}

std::size_t workspace_size_in_bytes() const override { return 0; }

~Operator() {}

void* default_workspace_{nullptr};
int core_per_cluster = 0;
int cluster_count = 0;
};

} // namespace infini::ops

#endif
Loading
Loading