Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions gen/arrays.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,18 +679,73 @@ DSliceValue *DtoAppendDCharToUnicodeString(Loc loc, DValue *arr,

////////////////////////////////////////////////////////////////////////////////
namespace {
// Create a call instruction to memcmp.
llvm::CallInst *callMemcmp(Loc loc, IRState &irs, LLValue *l_ptr,
LLValue *r_ptr, LLValue *numElements, LLType *elemty) {
// Inline byte-wise memcmp for `@compute` device code. The device has no libc to
// provide `memcmp`, and there is no `llvm.memcmp` intrinsic, so emit the
// comparison as a loop. Returns an i32 with C `memcmp`'s convention: 0 iff the
// `sizeInBytes` bytes are equal, otherwise the (unsigned) difference of the
// first mismatching byte.
llvm::Value *emitInlineMemcmp(IRState &irs, LLValue *l_ptr, LLValue *r_ptr,
LLValue *sizeInBytes) {
LLType *i8 = LLType::getInt8Ty(gIR->context());
LLType *i32 = LLType::getInt32Ty(gIR->context());

llvm::BasicBlock *entryBB = irs.ir->GetInsertBlock();
llvm::BasicBlock *condBB = irs.insertBB("memcmp.cond");
llvm::BasicBlock *bodyBB = irs.insertBBAfter(condBB, "memcmp.body");
llvm::BasicBlock *diffBB = irs.insertBBAfter(bodyBB, "memcmp.diff");
llvm::BasicBlock *incBB = irs.insertBBAfter(diffBB, "memcmp.inc");
llvm::BasicBlock *endBB = irs.insertBBAfter(incBB, "memcmp.end");
irs.ir->CreateBr(condBB);

// cond: index < sizeInBytes ?
irs.ir->SetInsertPoint(condBB);
llvm::PHINode *idx = irs.ir->CreatePHI(sizeInBytes->getType(), 2);
idx->addIncoming(DtoConstSize_t(0), entryBB);
irs.ir->CreateCondBr(
irs.ir->CreateICmp(llvm::ICmpInst::ICMP_ULT, idx, sizeInBytes), bodyBB,
endBB);

// body: compare the two bytes at `idx`
irs.ir->SetInsertPoint(bodyBB);
LLValue *lb = irs.ir->CreateLoad(i8, irs.ir->CreateInBoundsGEP(i8, l_ptr, idx));
LLValue *rb = irs.ir->CreateLoad(i8, irs.ir->CreateInBoundsGEP(i8, r_ptr, idx));
irs.ir->CreateCondBr(irs.ir->CreateICmp(llvm::ICmpInst::ICMP_EQ, lb, rb), incBB,
diffBB);

// diff: result = (unsigned)lb - (unsigned)rb
irs.ir->SetInsertPoint(diffBB);
LLValue *diff = irs.ir->CreateSub(irs.ir->CreateZExt(lb, i32),
irs.ir->CreateZExt(rb, i32));
irs.ir->CreateBr(endBB);

// inc: ++index
irs.ir->SetInsertPoint(incBB);
idx->addIncoming(irs.ir->CreateAdd(idx, DtoConstSize_t(1)), incBB);
irs.ir->CreateBr(condBB);

// end: 0 if the loop ran to completion, else the byte difference
irs.ir->SetInsertPoint(endBB);
llvm::PHINode *result = irs.ir->CreatePHI(i32, 2);
result->addIncoming(DtoConstInt(0), condBB);
result->addIncoming(diff, diffBB);
return result;
}

// Compare two memory blocks for the array lowering: a `memcmp` call on
// host, an inline loop on `@compute` device code (no libc).
llvm::Value *callMemcmp(Loc loc, IRState &irs, LLValue *l_ptr, LLValue *r_ptr,
LLValue *numElements, LLType *elemty) {
assert(l_ptr && r_ptr && numElements);
LLFunction *fn = getRuntimeFunction(loc, gIR->module, "memcmp");
assert(fn);
auto sizeInBytes = numElements;
size_t elementSize = getTypeAllocSize(elemty);
if (elementSize != 1) {
sizeInBytes = irs.ir->CreateMul(sizeInBytes, DtoConstSize_t(elementSize));
}
if (gIR->dcomputetarget)
return emitInlineMemcmp(irs, l_ptr, r_ptr, sizeInBytes);
// Call memcmp.
LLFunction *fn = getRuntimeFunction(loc, gIR->module, "memcmp");
assert(fn);
LLValue *args[] = {l_ptr, r_ptr, sizeInBytes};
return irs.ir->CreateCall(fn, args);
}
Expand Down Expand Up @@ -732,14 +787,17 @@ LLValue *DtoArrayEqCmp_memcmp(Loc loc, DValue *l, DValue *r, IRState &irs) {
// return 0 (equality) when the length is zero.
irs.ir->SetInsertPoint(memcmpBB);
auto memcmpAnswer = callMemcmp(loc, irs, l_ptr, r_ptr, l_length, DtoMemType(l->type->nextOf()));
// callMemcmp may emit extra blocks (the device inline loop), so branch to the
// merge point from whatever block we ended up in.
llvm::BasicBlock *memcmpResultBB = irs.ir->GetInsertBlock();
irs.ir->CreateBr(memcmpEndBB);

// Merge the result of length check and memcmp call into a phi node.
irs.ir->SetInsertPoint(memcmpEndBB);
llvm::PHINode *phi =
irs.ir->CreatePHI(LLType::getInt32Ty(gIR->context()), 2, "cmp_result");
phi->addIncoming(DtoConstInt(1), incomingBB);
phi->addIncoming(memcmpAnswer, memcmpBB);
phi->addIncoming(memcmpAnswer, memcmpResultBB);

return phi;
}
Expand Down
88 changes: 88 additions & 0 deletions tests/codegen/dcompute_array_eq_inline_memcmp.d
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Integral / POD-element array `==` in @compute device code takes the memcmp
// "fast path" (it does NOT instantiate the __equals hook). On the host that path
// lowers to a `memcmp` runtime call, but device targets have no libc `memcmp`
// and there is no `llvm.memcmp` intrinsic, so DtoArrayEqCmp_memcmp/callMemcmp
// must emit the comparison as an INLINE byte-wise loop instead. This test pins
// that lowering down:
// * the inline loop blocks (memcmp.cond/body/diff/inc/end) are emitted,
// * the byte count is numElements * elementSize (and == numElements when the
// element is 1 byte, i.e. no multiply),
// * static arrays skip the runtime length guard and use a constant byte count,
// * the merge phi reads its memcmp result from the loop's real exit block
// (regression guard for the GetInsertBlock() fix: callMemcmp now emits extra
// blocks, so the result no longer comes from the pre-loop block),
// * NO `memcmp` symbol/call and NO __equals hook are emitted for the device.
//
// REQUIRES: target_NVPTX
// RUN: %ldc -mdcompute-targets=cuda-700 -m64 -output-ll -output-o -c \
// RUN: -mdcompute-file-prefix=dcompute_eqinline %s
// RUN: FileCheck %s < dcompute_eqinline_cuda700_64.ll
// RUN: FileCheck %s --check-prefix=NOCALL < dcompute_eqinline_cuda700_64.ll

@compute(CompileFor.deviceOnly) module dcompute_array_eq_inline;
import ldc.dcompute;

// Dynamic int[]: runtime length guard, byte count = len*4, full inline loop,
// and the result phi must take the memcmp answer from %memcmp.end.
// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_int
// CHECK: icmp eq i64
// CHECK: br i1 {{.*}}, label %domemcmp, label %memcmpend
// CHECK: domemcmp:
// CHECK: mul i64 {{.*}}, 4
// CHECK: br label %memcmp.cond
// CHECK: memcmp.cond:
// CHECK: icmp ult i64
// CHECK: br i1 {{.*}}, label %memcmp.body, label %memcmp.end
// CHECK: memcmp.body:
// CHECK: load i8
// CHECK: load i8
// CHECK: icmp eq i8
// CHECK: br i1 {{.*}}, label %memcmp.inc, label %memcmp.diff
// CHECK: memcmp.diff:
// CHECK: zext i8
// CHECK: sub i32
// CHECK: memcmp.inc:
// CHECK: add i64 {{.*}}, 1
// CHECK: memcmp.end:
// CHECK: phi i32 [ 0, %memcmp.cond ]
// CHECK: memcmpend:
// CHECK: phi i32 [ 1, {{.*}} ], [ {{.*}}, %memcmp.end ]
@kernel void dyn_int(int[] a, int[] b, bool* o) { *o = (a == b); }

// Dynamic byte[]: element size 1, so the byte count IS numElements -- there must
// be NO multiply between the length guard and the loop.
// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_byte
// CHECK: domemcmp:
// CHECK-NOT: mul i64
// CHECK: memcmp.cond:
// CHECK: memcmp.body:
// CHECK: load i8
// CHECK: icmp eq i8
@kernel void dyn_byte(byte[] a, byte[] b, bool* o) { *o = (a == b); }

// Dynamic short[]: element size 2 -> byte count = len*2.
// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_short
// CHECK: domemcmp:
// CHECK: mul i64 {{.*}}, 2
// CHECK: memcmp.cond:
@kernel void dyn_short(short[] a, short[] b, bool* o) { *o = (a == b); }

// Static int[4]: lengths are statically equal, so NO runtime length guard and a
// CONSTANT byte count of 16 (4 elements * 4 bytes).
// CHECK-LABEL: define{{.*}}ptx_device{{.*}}stat_int4
// CHECK-NOT: domemcmp
// CHECK: icmp ult i64 {{.*}}, 16
// CHECK: memcmp.body:
// CHECK: load i8
@kernel void stat_int4(int[4] a, int[4] b, bool* o) { *o = (a == b); }

// Static long[3]: constant byte count 24 (3 elements * 8 bytes).
// CHECK-LABEL: define{{.*}}ptx_device{{.*}}stat_long3
// CHECK-NOT: domemcmp
// CHECK: icmp ult i64 {{.*}}, 24
@kernel void stat_long3(long[3] a, long[3] b, bool* o) { *o = (a == b); }

// Device code must contain NO libc memcmp call/declare and must NOT fall back to
// the __equals hook for these integral/POD element types.
// NOCALL-NOT: @memcmp
// NOCALL-NOT: __equals
Loading