|
| 1 | +#include "lib/Conversion/PolyToStandard/PolyToStandard.h" |
| 2 | + |
| 3 | +#include "lib/Dialect/Poly/PolyOps.h" |
| 4 | +#include "lib/Dialect/Poly/PolyTypes.h" |
| 5 | +#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project |
| 6 | +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project |
| 7 | +#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
| 8 | +#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project |
| 9 | +#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project |
| 10 | +#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project |
| 11 | + |
| 12 | +namespace mlir { |
| 13 | +namespace tutorial { |
| 14 | +namespace poly { |
| 15 | + |
| 16 | +#define GEN_PASS_DEF_POLYTOSTANDARD |
| 17 | +#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc" |
| 18 | + |
| 19 | +class PolyToStandardTypeConverter : public TypeConverter { |
| 20 | + public: |
| 21 | + PolyToStandardTypeConverter(MLIRContext *ctx) { |
| 22 | + addConversion([](Type type) { return type; }); |
| 23 | + addConversion([ctx](PolynomialType type) -> Type { |
| 24 | + int degreeBound = type.getDegreeBound(); |
| 25 | + IntegerType elementTy = |
| 26 | + IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless); |
| 27 | + return RankedTensorType::get({degreeBound}, elementTy); |
| 28 | + }); |
| 29 | + |
| 30 | + // We don't include any custom materialization hooks because this lowering |
| 31 | + // is all done in a single pass. The dialect conversion framework works by |
| 32 | + // resolving intermediate (mid-pass) type conflicts by inserting |
| 33 | + // unrealized_conversion_cast ops, and only converting those to custom |
| 34 | + // materializations if they persist at the end of the pass. In our case, |
| 35 | + // we'd only need to use custom materializations if we split this lowering |
| 36 | + // across multiple passes. |
| 37 | + } |
| 38 | +}; |
| 39 | + |
| 40 | +struct ConvertAdd : public OpConversionPattern<AddOp> { |
| 41 | + ConvertAdd(mlir::MLIRContext *context) |
| 42 | + : OpConversionPattern<AddOp>(context) {} |
| 43 | + |
| 44 | + using OpConversionPattern::OpConversionPattern; |
| 45 | + |
| 46 | + LogicalResult matchAndRewrite( |
| 47 | + AddOp op, OpAdaptor adaptor, |
| 48 | + ConversionPatternRewriter &rewriter) const override { |
| 49 | + arith::AddIOp addOp = rewriter.create<arith::AddIOp>( |
| 50 | + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); |
| 51 | + rewriter.replaceOp(op.getOperation(), {addOp}); |
| 52 | + return success(); |
| 53 | + } |
| 54 | +}; |
| 55 | + |
| 56 | +struct ConvertSub : public OpConversionPattern<SubOp> { |
| 57 | + ConvertSub(mlir::MLIRContext *context) |
| 58 | + : OpConversionPattern<SubOp>(context) {} |
| 59 | + |
| 60 | + using OpConversionPattern::OpConversionPattern; |
| 61 | + |
| 62 | + LogicalResult matchAndRewrite( |
| 63 | + SubOp op, OpAdaptor adaptor, |
| 64 | + ConversionPatternRewriter &rewriter) const override { |
| 65 | + arith::SubIOp subOp = rewriter.create<arith::SubIOp>( |
| 66 | + op.getLoc(), adaptor.getLhs(), adaptor.getRhs()); |
| 67 | + rewriter.replaceOp(op.getOperation(), {subOp}); |
| 68 | + return success(); |
| 69 | + } |
| 70 | +}; |
| 71 | + |
struct ConvertMul : public OpConversionPattern<MulOp> {
  ConvertMul(mlir::MLIRContext *context)
      : OpConversionPattern<MulOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  // Lowers poly.mul to a naive O(N^2) schoolbook multiplication over the
  // coefficient tensors, wrapping the output index mod N (the RemUIOp
  // below), i.e. a cyclic convolution of the two coefficient vectors.
  LogicalResult matchAndRewrite(
      MulOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Both operands have already been converted to tensors of the same
    // shape by the type converter; N is read off the lhs.
    auto polymulTensorType = cast<RankedTensorType>(adaptor.getLhs().getType());
    auto numTerms = polymulTensorType.getShape()[0];
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Create an all-zeros tensor to store the result
    auto polymulResult = b.create<arith::ConstantOp>(
        polymulTensorType, DenseElementsAttr::get(polymulTensorType, 0));

    // Loop bounds and step.
    auto lowerBound =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(0));
    auto numTermsOp =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(numTerms));
    auto step =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1));

    auto p0 = adaptor.getLhs();
    auto p1 = adaptor.getRhs();

    // for i = 0, ..., N-1
    //   for j = 0, ..., N-1
    //     product[i+j (mod N)] += p0[i] * p1[j]
    //
    // The accumulator tensor is threaded through both loops as a
    // loop-carried value (iter_arg), keeping the accumulation in SSA form:
    // each inner iteration yields a fresh tensor with one element updated.
    auto outerLoop = b.create<scf::ForOp>(
        lowerBound, numTermsOp, step, ValueRange(polymulResult.getResult()),
        [&](OpBuilder &builder, Location loc, Value p0Index,
            ValueRange loopState) {
          // Shadow `b` with a builder anchored at the loop-body insertion
          // point; reusing the outer builder here would insert in the wrong
          // block.
          ImplicitLocOpBuilder b(op.getLoc(), builder);
          auto innerLoop = b.create<scf::ForOp>(
              lowerBound, numTermsOp, step, loopState,
              [&](OpBuilder &builder, Location loc, Value p1Index,
                  ValueRange loopState) {
                ImplicitLocOpBuilder b(op.getLoc(), builder);
                auto accumTensor = loopState.front();
                // Destination index (i + j) mod N — the wraparound that
                // makes this a cyclic convolution.
                auto destIndex = b.create<arith::RemUIOp>(
                    b.create<arith::AddIOp>(p0Index, p1Index), numTermsOp);
                auto mulOp = b.create<arith::MulIOp>(
                    b.create<tensor::ExtractOp>(p0, ValueRange(p0Index)),
                    b.create<tensor::ExtractOp>(p1, ValueRange(p1Index)));
                // product[dest] += p0[i] * p1[j], via extract/add/insert.
                auto result = b.create<arith::AddIOp>(
                    mulOp, b.create<tensor::ExtractOp>(accumTensor,
                                                       destIndex.getResult()));
                auto stored = b.create<tensor::InsertOp>(result, accumTensor,
                                                         destIndex.getResult());
                b.create<scf::YieldOp>(stored.getResult());
              });

          b.create<scf::YieldOp>(innerLoop.getResults());
        });

    // The final loop-carried tensor is the product polynomial.
    rewriter.replaceOp(op, outerLoop.getResult(0));
    return success();
  }
};
| 134 | + |
| 135 | +struct ConvertEval : public OpConversionPattern<EvalOp> { |
| 136 | + ConvertEval(mlir::MLIRContext *context) |
| 137 | + : OpConversionPattern<EvalOp>(context) {} |
| 138 | + |
| 139 | + using OpConversionPattern::OpConversionPattern; |
| 140 | + |
| 141 | + LogicalResult matchAndRewrite( |
| 142 | + EvalOp op, OpAdaptor adaptor, |
| 143 | + ConversionPatternRewriter &rewriter) const override { |
| 144 | + auto polyTensorType = |
| 145 | + cast<RankedTensorType>(adaptor.getPolynomial().getType()); |
| 146 | + auto numTerms = polyTensorType.getShape()[0]; |
| 147 | + ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 148 | + |
| 149 | + auto lowerBound = |
| 150 | + b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1)); |
| 151 | + auto numTermsOp = b.create<arith::ConstantOp>(b.getIndexType(), |
| 152 | + b.getIndexAttr(numTerms + 1)); |
| 153 | + auto step = lowerBound; |
| 154 | + |
| 155 | + auto poly = adaptor.getPolynomial(); |
| 156 | + auto point = adaptor.getPoint(); |
| 157 | + |
| 158 | + // Horner's method: |
| 159 | + // |
| 160 | + // accum = 0 |
| 161 | + // for i = 1, 2, ..., N |
| 162 | + // accum = point * accum + coeff[N - i] |
| 163 | + auto accum = |
| 164 | + b.create<arith::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(0)); |
| 165 | + auto loop = b.create<scf::ForOp>( |
| 166 | + lowerBound, numTermsOp, step, accum.getResult(), |
| 167 | + [&](OpBuilder &builder, Location loc, Value loopIndex, |
| 168 | + ValueRange loopState) { |
| 169 | + ImplicitLocOpBuilder b(op.getLoc(), builder); |
| 170 | + auto accum = loopState.front(); |
| 171 | + auto coeffIndex = b.create<arith::SubIOp>(numTermsOp, loopIndex); |
| 172 | + auto mulOp = b.create<arith::MulIOp>(point, accum); |
| 173 | + auto result = b.create<arith::AddIOp>( |
| 174 | + mulOp, b.create<tensor::ExtractOp>(poly, coeffIndex.getResult())); |
| 175 | + b.create<scf::YieldOp>(result.getResult()); |
| 176 | + }); |
| 177 | + |
| 178 | + rewriter.replaceOp(op, loop.getResult(0)); |
| 179 | + return success(); |
| 180 | + } |
| 181 | +}; |
| 182 | + |
| 183 | +struct ConvertFromTensor : public OpConversionPattern<FromTensorOp> { |
| 184 | + ConvertFromTensor(mlir::MLIRContext *context) |
| 185 | + : OpConversionPattern<FromTensorOp>(context) {} |
| 186 | + |
| 187 | + using OpConversionPattern::OpConversionPattern; |
| 188 | + |
| 189 | + LogicalResult matchAndRewrite( |
| 190 | + FromTensorOp op, OpAdaptor adaptor, |
| 191 | + ConversionPatternRewriter &rewriter) const override { |
| 192 | + auto resultTensorTy = cast<RankedTensorType>( |
| 193 | + typeConverter->convertType(op->getResultTypes()[0])); |
| 194 | + auto resultShape = resultTensorTy.getShape()[0]; |
| 195 | + auto resultEltTy = resultTensorTy.getElementType(); |
| 196 | + |
| 197 | + auto inputTensorTy = op.getInput().getType(); |
| 198 | + auto inputShape = inputTensorTy.getShape()[0]; |
| 199 | + |
| 200 | + // Zero pad the tensor if the coefficients' size is less than the polynomial |
| 201 | + // degree. |
| 202 | + ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 203 | + auto coeffValue = adaptor.getInput(); |
| 204 | + if (inputShape < resultShape) { |
| 205 | + SmallVector<OpFoldResult, 1> low, high; |
| 206 | + low.push_back(rewriter.getIndexAttr(0)); |
| 207 | + high.push_back(rewriter.getIndexAttr(resultShape - inputShape)); |
| 208 | + coeffValue = b.create<tensor::PadOp>( |
| 209 | + resultTensorTy, coeffValue, low, high, |
| 210 | + b.create<arith::ConstantOp>(rewriter.getIntegerAttr(resultEltTy, 0)), |
| 211 | + /*nofold=*/false); |
| 212 | + } |
| 213 | + |
| 214 | + rewriter.replaceOp(op, coeffValue); |
| 215 | + return success(); |
| 216 | + } |
| 217 | +}; |
| 218 | + |
| 219 | +struct ConvertToTensor : public OpConversionPattern<ToTensorOp> { |
| 220 | + ConvertToTensor(mlir::MLIRContext *context) |
| 221 | + : OpConversionPattern<ToTensorOp>(context) {} |
| 222 | + |
| 223 | + using OpConversionPattern::OpConversionPattern; |
| 224 | + |
| 225 | + LogicalResult matchAndRewrite( |
| 226 | + ToTensorOp op, OpAdaptor adaptor, |
| 227 | + ConversionPatternRewriter &rewriter) const override { |
| 228 | + rewriter.replaceOp(op, adaptor.getInput()); |
| 229 | + return success(); |
| 230 | + } |
| 231 | +}; |
| 232 | + |
| 233 | +struct ConvertConstant : public OpConversionPattern<ConstantOp> { |
| 234 | + ConvertConstant(mlir::MLIRContext *context) |
| 235 | + : OpConversionPattern<ConstantOp>(context) {} |
| 236 | + |
| 237 | + using OpConversionPattern::OpConversionPattern; |
| 238 | + |
| 239 | + LogicalResult matchAndRewrite( |
| 240 | + ConstantOp op, OpAdaptor adaptor, |
| 241 | + ConversionPatternRewriter &rewriter) const override { |
| 242 | + ImplicitLocOpBuilder b(op.getLoc(), rewriter); |
| 243 | + auto constOp = b.create<arith::ConstantOp>(adaptor.getCoefficients()); |
| 244 | + auto fromTensorOp = |
| 245 | + b.create<FromTensorOp>(op.getResult().getType(), constOp); |
| 246 | + rewriter.replaceOp(op, fromTensorOp.getResult()); |
| 247 | + return success(); |
| 248 | + } |
| 249 | +}; |
| 250 | + |
struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
  using PolyToStandardBase::PolyToStandardBase;

  // Pass entry point: marks the poly dialect illegal, registers the op
  // lowering patterns above plus the upstream boilerplate patterns that
  // rewrite function signatures, calls, returns, and branches whose types
  // mention poly, then runs a partial conversion over the module.
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    auto *module = getOperation();

    ConversionTarget target(*context);
    target.addLegalDialect<arith::ArithDialect>();
    target.addIllegalDialect<PolyDialect>();

    RewritePatternSet patterns(context);
    PolyToStandardTypeConverter typeConverter(context);
    patterns.add<ConvertAdd, ConvertConstant, ConvertSub, ConvertEval,
                 ConvertMul, ConvertFromTensor, ConvertToTensor>(typeConverter,
                                                                 context);

    // Rewrite func.func signatures and entry-block arguments through the
    // type converter. A func.func is legal only once no poly types remain
    // in its signature or in the blocks of its body.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
             typeConverter.isLegal(&op.getBody());
    });

    // Likewise convert the operand types of func.return ...
    populateReturnOpTypeConversionPattern(patterns, typeConverter);
    target.addDynamicallyLegalOp<func::ReturnOp>(
        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });

    // ... and the operand/result types of func.call.
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return typeConverter.isLegal(op); });

    // Any other op is legal unless it is branch-like or return-like and
    // still carries operands of illegal (poly) type.
    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
             isLegalForBranchOpInterfaceTypeConversionPattern(op,
                                                              typeConverter) ||
             isLegalForReturnOpTypeConversionPattern(op, typeConverter);
    });

    // Partial conversion: ops not covered by the target above are left
    // alone; failure to legalize an illegal op fails the pass.
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
| 296 | + |
| 297 | +} // namespace poly |
| 298 | +} // namespace tutorial |
| 299 | +} // namespace mlir |
0 commit comments