diff --git a/gen/arrays.cpp b/gen/arrays.cpp index 626791854a..26c782cf11 100644 --- a/gen/arrays.cpp +++ b/gen/arrays.cpp @@ -679,18 +679,73 @@ DSliceValue *DtoAppendDCharToUnicodeString(Loc loc, DValue *arr, //////////////////////////////////////////////////////////////////////////////// namespace { -// Create a call instruction to memcmp. -llvm::CallInst *callMemcmp(Loc loc, IRState &irs, LLValue *l_ptr, - LLValue *r_ptr, LLValue *numElements, LLType *elemty) { +// Inline byte-wise memcmp for `@compute` device code. The device has no libc to +// provide `memcmp`, and there is no `llvm.memcmp` intrinsic, so emit the +// comparison as a loop. Returns an i32 with C `memcmp`'s convention: 0 iff the +// `sizeInBytes` bytes are equal, otherwise the signed difference of the +// first mismatching byte pair (left minus right), each byte zero-extended. +llvm::Value *emitInlineMemcmp(IRState &irs, LLValue *l_ptr, LLValue *r_ptr, + LLValue *sizeInBytes) { + LLType *i8 = LLType::getInt8Ty(gIR->context()); + LLType *i32 = LLType::getInt32Ty(gIR->context()); + + llvm::BasicBlock *entryBB = irs.ir->GetInsertBlock(); + llvm::BasicBlock *condBB = irs.insertBB("memcmp.cond"); + llvm::BasicBlock *bodyBB = irs.insertBBAfter(condBB, "memcmp.body"); + llvm::BasicBlock *diffBB = irs.insertBBAfter(bodyBB, "memcmp.diff"); + llvm::BasicBlock *incBB = irs.insertBBAfter(diffBB, "memcmp.inc"); + llvm::BasicBlock *endBB = irs.insertBBAfter(incBB, "memcmp.end"); + irs.ir->CreateBr(condBB); + + // cond: index < sizeInBytes ? 
+ irs.ir->SetInsertPoint(condBB); + llvm::PHINode *idx = irs.ir->CreatePHI(sizeInBytes->getType(), 2); + idx->addIncoming(DtoConstSize_t(0), entryBB); + irs.ir->CreateCondBr( + irs.ir->CreateICmp(llvm::ICmpInst::ICMP_ULT, idx, sizeInBytes), bodyBB, + endBB); + + // body: compare the two bytes at `idx` + irs.ir->SetInsertPoint(bodyBB); + LLValue *lb = irs.ir->CreateLoad(i8, irs.ir->CreateInBoundsGEP(i8, l_ptr, idx)); + LLValue *rb = irs.ir->CreateLoad(i8, irs.ir->CreateInBoundsGEP(i8, r_ptr, idx)); + irs.ir->CreateCondBr(irs.ir->CreateICmp(llvm::ICmpInst::ICMP_EQ, lb, rb), incBB, + diffBB); + + // diff: result = (unsigned)lb - (unsigned)rb + irs.ir->SetInsertPoint(diffBB); + LLValue *diff = irs.ir->CreateSub(irs.ir->CreateZExt(lb, i32), + irs.ir->CreateZExt(rb, i32)); + irs.ir->CreateBr(endBB); + + // inc: ++index + irs.ir->SetInsertPoint(incBB); + idx->addIncoming(irs.ir->CreateAdd(idx, DtoConstSize_t(1)), incBB); + irs.ir->CreateBr(condBB); + + // end: 0 if the loop ran to completion, else the byte difference + irs.ir->SetInsertPoint(endBB); + llvm::PHINode *result = irs.ir->CreatePHI(i32, 2); + result->addIncoming(DtoConstInt(0), condBB); + result->addIncoming(diff, diffBB); + return result; +} + +// Compare two memory blocks for the array lowering: a `memcmp` call on +// host, an inline loop on `@compute` device code (no libc). +llvm::Value *callMemcmp(Loc loc, IRState &irs, LLValue *l_ptr, LLValue *r_ptr, + LLValue *numElements, LLType *elemty) { assert(l_ptr && r_ptr && numElements); - LLFunction *fn = getRuntimeFunction(loc, gIR->module, "memcmp"); - assert(fn); auto sizeInBytes = numElements; size_t elementSize = getTypeAllocSize(elemty); if (elementSize != 1) { sizeInBytes = irs.ir->CreateMul(sizeInBytes, DtoConstSize_t(elementSize)); } + if (gIR->dcomputetarget) + return emitInlineMemcmp(irs, l_ptr, r_ptr, sizeInBytes); // Call memcmp. 
+ LLFunction *fn = getRuntimeFunction(loc, gIR->module, "memcmp"); + assert(fn); LLValue *args[] = {l_ptr, r_ptr, sizeInBytes}; return irs.ir->CreateCall(fn, args); } @@ -732,6 +787,9 @@ LLValue *DtoArrayEqCmp_memcmp(Loc loc, DValue *l, DValue *r, IRState &irs) { // return 0 (equality) when the length is zero. irs.ir->SetInsertPoint(memcmpBB); auto memcmpAnswer = callMemcmp(loc, irs, l_ptr, r_ptr, l_length, DtoMemType(l->type->nextOf())); + // callMemcmp may emit extra blocks (the device inline loop), so branch to the + // merge point from whatever block we ended up in. + llvm::BasicBlock *memcmpResultBB = irs.ir->GetInsertBlock(); irs.ir->CreateBr(memcmpEndBB); // Merge the result of length check and memcmp call into a phi node. @@ -739,7 +797,7 @@ LLValue *DtoArrayEqCmp_memcmp(Loc loc, DValue *l, DValue *r, IRState &irs) { llvm::PHINode *phi = irs.ir->CreatePHI(LLType::getInt32Ty(gIR->context()), 2, "cmp_result"); phi->addIncoming(DtoConstInt(1), incomingBB); - phi->addIncoming(memcmpAnswer, memcmpBB); + phi->addIncoming(memcmpAnswer, memcmpResultBB); return phi; } diff --git a/tests/codegen/dcompute_array_eq_inline_memcmp.d b/tests/codegen/dcompute_array_eq_inline_memcmp.d new file mode 100644 index 0000000000..3cfab2e5ee --- /dev/null +++ b/tests/codegen/dcompute_array_eq_inline_memcmp.d @@ -0,0 +1,88 @@ +// Integral / POD-element array `==` in @compute device code takes the memcmp +// "fast path" (it does NOT instantiate the __equals hook). On the host that path +// lowers to a `memcmp` runtime call, but device targets have no libc `memcmp` +// and there is no `llvm.memcmp` intrinsic, so DtoArrayEqCmp_memcmp/callMemcmp +// must emit the comparison as an INLINE byte-wise loop instead. This test pins +// that lowering down: +// * the inline loop blocks (memcmp.cond/body/diff/inc/end) are emitted, +// * the byte count is numElements * elementSize (and == numElements when the +// element is 1 byte, i.e. 
no multiply), +// * static arrays skip the runtime length guard and use a constant byte count, +// * the merge phi reads its memcmp result from the loop's real exit block +// (regression guard for the GetInsertBlock() fix: callMemcmp now emits extra +// blocks, so the result no longer comes from the pre-loop block), +// * NO `memcmp` symbol/call and NO __equals hook are emitted for the device. +// +// REQUIRES: target_NVPTX +// RUN: %ldc -mdcompute-targets=cuda-700 -m64 -output-ll -output-o -c \ +// RUN: -mdcompute-file-prefix=dcompute_eqinline %s +// RUN: FileCheck %s < dcompute_eqinline_cuda700_64.ll +// RUN: FileCheck %s --check-prefix=NOCALL < dcompute_eqinline_cuda700_64.ll + +@compute(CompileFor.deviceOnly) module dcompute_array_eq_inline; +import ldc.dcompute; + +// Dynamic int[]: runtime length guard, byte count = len*4, full inline loop, +// and the result phi must take the memcmp answer from %memcmp.end. +// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_int +// CHECK: icmp eq i64 +// CHECK: br i1 {{.*}}, label %domemcmp, label %memcmpend +// CHECK: domemcmp: +// CHECK: mul i64 {{.*}}, 4 +// CHECK: br label %memcmp.cond +// CHECK: memcmp.cond: +// CHECK: icmp ult i64 +// CHECK: br i1 {{.*}}, label %memcmp.body, label %memcmp.end +// CHECK: memcmp.body: +// CHECK: load i8 +// CHECK: load i8 +// CHECK: icmp eq i8 +// CHECK: br i1 {{.*}}, label %memcmp.inc, label %memcmp.diff +// CHECK: memcmp.diff: +// CHECK: zext i8 +// CHECK: sub i32 +// CHECK: memcmp.inc: +// CHECK: add i64 {{.*}}, 1 +// CHECK: memcmp.end: +// CHECK: phi i32 [ 0, %memcmp.cond ] +// CHECK: memcmpend: +// CHECK: phi i32 [ 1, {{.*}} ], [ {{.*}}, %memcmp.end ] +@kernel void dyn_int(int[] a, int[] b, bool* o) { *o = (a == b); } + +// Dynamic byte[]: element size 1, so the byte count IS numElements -- there must +// be NO multiply between the length guard and the loop. 
+// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_byte +// CHECK: domemcmp: +// CHECK-NOT: mul i64 +// CHECK: memcmp.cond: +// CHECK: memcmp.body: +// CHECK: load i8 +// CHECK: icmp eq i8 +@kernel void dyn_byte(byte[] a, byte[] b, bool* o) { *o = (a == b); } + +// Dynamic short[]: element size 2 -> byte count = len*2. +// CHECK-LABEL: define{{.*}}ptx_device{{.*}}dyn_short +// CHECK: domemcmp: +// CHECK: mul i64 {{.*}}, 2 +// CHECK: memcmp.cond: +@kernel void dyn_short(short[] a, short[] b, bool* o) { *o = (a == b); } + +// Static int[4]: lengths are statically equal, so NO runtime length guard and a +// CONSTANT byte count of 16 (4 elements * 4 bytes). +// CHECK-LABEL: define{{.*}}ptx_device{{.*}}stat_int4 +// CHECK-NOT: domemcmp +// CHECK: icmp ult i64 {{.*}}, 16 +// CHECK: memcmp.body: +// CHECK: load i8 +@kernel void stat_int4(int[4] a, int[4] b, bool* o) { *o = (a == b); } + +// Static long[3]: constant byte count 24 (3 elements * 8 bytes). +// CHECK-LABEL: define{{.*}}ptx_device{{.*}}stat_long3 +// CHECK-NOT: domemcmp +// CHECK: icmp ult i64 {{.*}}, 24 +@kernel void stat_long3(long[3] a, long[3] b, bool* o) { *o = (a == b); } + +// Device code must contain NO libc memcmp call/declare and must NOT fall back to +// the __equals hook for these integral/POD element types. +// NOCALL-NOT: @memcmp +// NOCALL-NOT: __equals