Skip to content

Commit fac42a3

Browse files
authored
Dialect Conversion (#20)
* Add poly-to-standard lowering pass shell * Add conversion target * add poly type converter * add empty lowering for poly.add * lower poly.add * add structural conversion helpers for func, etc. * add poly.to_tensor op * lower poly sub, to_tensor, from_tensor ops * lower poly.constant * lower poly.mul as a naive loop * Lower poly.eval * add CMake build files * add a test that chains together many poly ops * demonstrate what a materialization would involve * Revert "demonstrate what a materialization would involve" This reverts commit 1f5dfdd. --------- Co-authored-by: Jeremy Kun <j2kun@users.noreply.github.com>
1 parent e6cfb75 commit fac42a3

16 files changed

Lines changed: 551 additions & 7 deletions

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,4 @@ externals
1919
/CMakeSettings.json
2020
# Compilation databases
2121
compile_commands.json
22-
tablegen_compile_commands.yml
22+
tablegen_compile_commands.yml

lib/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(Dialect)
2-
add_subdirectory(Transform)
2+
add_subdirectory(Conversion)
3+
add_subdirectory(Transform)

lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
add_subdirectory(PolyToStandard)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
2+
3+
package(
    default_visibility = ["//visibility:public"],
)

# Runs mlir-tblgen over PolyToStandard.td to generate the pass declarations
# header (PolyToStandard.h.inc).
gentbl_cc_library(
    name = "pass_inc_gen",
    tbl_outs = [
        (
            [
                "-gen-pass-decls",
                "-name=PolyToStandard",
            ],
            "PolyToStandard.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "PolyToStandard.td",
    deps = [
        "@llvm-project//mlir:OpBaseTdFiles",
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

# The poly-to-standard conversion pass itself.
cc_library(
    name = "PolyToStandard",
    srcs = ["PolyToStandard.cpp"],
    hdrs = ["PolyToStandard.h"],
    deps = [
        # Generated pass declarations from the rule above (same package).
        ":pass_inc_gen",
        "//lib/Dialect/Poly",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:FuncTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:TensorDialect",
        "@llvm-project//mlir:Transforms",
    ],
)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Library containing the poly-to-standard conversion pass.
add_mlir_library(PolyToStandard
    PolyToStandard.cpp

    # ADDITIONAL_HEADER_DIRS is a multi-value keyword: the directories must
    # follow it. Listing the directory before the keyword makes it parse as a
    # source entry and leaves the keyword without values.
    ADDITIONAL_HEADER_DIRS
    ${PROJECT_SOURCE_DIR}/lib/Conversion/PolyToStandard/

    DEPENDS
    PolyToStandardPassIncGen

    LINK_COMPONENTS
    Core

    LINK_LIBS PUBLIC
    MLIRPoly
    MLIRArithDialect
    MLIRFuncDialect
    MLIRFuncTransforms
    MLIRIR
    MLIRPass
    MLIRSCFDialect
    MLIRTensorDialect
    MLIRTransforms
)

# Generate the pass declarations header from PolyToStandard.td.
set(LLVM_TARGET_DEFINITIONS PolyToStandard.td)
mlir_tablegen(PolyToStandard.h.inc -gen-pass-decls -name PolyToStandard)
# NOTE(review): this makes mlir-headers depend on the Poly dialect's ops
# tablegen target rather than on PolyToStandardPassIncGen -- confirm this is
# intentional and not a copy/paste leftover.
add_dependencies(mlir-headers MLIRPolyOpsIncGen)
add_public_tablegen_target(PolyToStandardPassIncGen)
add_mlir_doc(PolyToStandard PolyToStandard PolyToStandard/ -gen-pass-doc)
Lines changed: 299 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,299 @@
1+
#include "lib/Conversion/PolyToStandard/PolyToStandard.h"
2+
3+
#include "lib/Dialect/Poly/PolyOps.h"
4+
#include "lib/Dialect/Poly/PolyTypes.h"
5+
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
6+
#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project
7+
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
8+
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
10+
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
11+
12+
namespace mlir {
13+
namespace tutorial {
14+
namespace poly {
15+
16+
#define GEN_PASS_DEF_POLYTOSTANDARD
17+
#include "lib/Conversion/PolyToStandard/PolyToStandard.h.inc"
18+
19+
class PolyToStandardTypeConverter : public TypeConverter {
20+
public:
21+
PolyToStandardTypeConverter(MLIRContext *ctx) {
22+
addConversion([](Type type) { return type; });
23+
addConversion([ctx](PolynomialType type) -> Type {
24+
int degreeBound = type.getDegreeBound();
25+
IntegerType elementTy =
26+
IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Signless);
27+
return RankedTensorType::get({degreeBound}, elementTy);
28+
});
29+
30+
// We don't include any custom materialization hooks because this lowering
31+
// is all done in a single pass. The dialect conversion framework works by
32+
// resolving intermediate (mid-pass) type conflicts by inserting
33+
// unrealized_conversion_cast ops, and only converting those to custom
34+
// materializations if they persist at the end of the pass. In our case,
35+
// we'd only need to use custom materializations if we split this lowering
36+
// across multiple passes.
37+
}
38+
};
39+
40+
/// Rewrites poly.add as an arith.addi applied to the type-converted operands.
struct ConvertAdd : public OpConversionPattern<AddOp> {
  ConvertAdd(mlir::MLIRContext *context)
      : OpConversionPattern<AddOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      AddOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The adaptor exposes the operands after type conversion.
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    auto sum = rewriter.create<arith::AddIOp>(op.getLoc(), lhs, rhs);
    rewriter.replaceOp(op, sum.getResult());
    return success();
  }
};
55+
56+
/// Rewrites poly.sub as an arith.subi applied to the type-converted operands.
struct ConvertSub : public OpConversionPattern<SubOp> {
  ConvertSub(mlir::MLIRContext *context)
      : OpConversionPattern<SubOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      SubOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The adaptor exposes the operands after type conversion.
    Value lhs = adaptor.getLhs();
    Value rhs = adaptor.getRhs();
    auto difference = rewriter.create<arith::SubIOp>(op.getLoc(), lhs, rhs);
    rewriter.replaceOp(op, difference.getResult());
    return success();
  }
};
71+
72+
/// Rewrites poly.mul as a naive doubly-nested scf.for loop over the
/// coefficient tensors of both operands, accumulating into a result tensor
/// whose index wraps around modulo the number of terms.
struct ConvertMul : public OpConversionPattern<MulOp> {
  ConvertMul(mlir::MLIRContext *context)
      : OpConversionPattern<MulOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      MulOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto lhsCoeffs = adaptor.getLhs();
    auto rhsCoeffs = adaptor.getRhs();
    auto coeffTensorTy = cast<RankedTensorType>(lhsCoeffs.getType());
    auto termCount = coeffTensorTy.getShape()[0];
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Initial value of the accumulator: a tensor filled with zeros.
    auto zeroTensor = b.create<arith::ConstantOp>(
        coeffTensorTy, DenseElementsAttr::get(coeffTensorTy, 0));

    // Both loops share the same bounds and step: [0, termCount) by 1.
    auto zeroIdx =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(0));
    auto termCountIdx = b.create<arith::ConstantOp>(b.getIndexType(),
                                                    b.getIndexAttr(termCount));
    auto oneIdx =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1));

    // for i = 0, ..., N-1
    //   for j = 0, ..., N-1
    //     product[i+j (mod N)] += p0[i] * p1[j]
    auto outerLoop = b.create<scf::ForOp>(
        zeroIdx, termCountIdx, oneIdx, ValueRange(zeroTensor.getResult()),
        [&](OpBuilder &outerBuilder, Location, Value lhsIdx,
            ValueRange outerArgs) {
          ImplicitLocOpBuilder ob(op.getLoc(), outerBuilder);
          auto innerLoop = ob.create<scf::ForOp>(
              zeroIdx, termCountIdx, oneIdx, outerArgs,
              [&](OpBuilder &innerBuilder, Location, Value rhsIdx,
                  ValueRange innerArgs) {
                ImplicitLocOpBuilder ib(op.getLoc(), innerBuilder);
                Value partial = innerArgs.front();

                // Destination index: (i + j) mod N.
                auto indexSum = ib.create<arith::AddIOp>(lhsIdx, rhsIdx);
                auto wrappedIdx =
                    ib.create<arith::RemUIOp>(indexSum, termCountIdx);

                // p0[i] * p1[j]
                auto lhsCoeff = ib.create<tensor::ExtractOp>(
                    lhsCoeffs, ValueRange(lhsIdx));
                auto rhsCoeff = ib.create<tensor::ExtractOp>(
                    rhsCoeffs, ValueRange(rhsIdx));
                auto product = ib.create<arith::MulIOp>(lhsCoeff, rhsCoeff);

                // Add the product to the current entry and write it back.
                auto previous = ib.create<tensor::ExtractOp>(
                    partial, wrappedIdx.getResult());
                auto updated = ib.create<arith::AddIOp>(product, previous);
                auto stored = ib.create<tensor::InsertOp>(
                    updated, partial, wrappedIdx.getResult());
                ib.create<scf::YieldOp>(stored.getResult());
              });

          // Forward the inner loop's accumulator to the next outer iteration.
          ob.create<scf::YieldOp>(innerLoop.getResults());
        });

    rewriter.replaceOp(op, outerLoop.getResult(0));
    return success();
  }
};
134+
135+
/// Rewrites poly.eval as an scf.for loop that evaluates the polynomial's
/// coefficient tensor at the given point using Horner's method.
struct ConvertEval : public OpConversionPattern<EvalOp> {
  ConvertEval(mlir::MLIRContext *context)
      : OpConversionPattern<EvalOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      EvalOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    auto polyTensorType =
        cast<RankedTensorType>(adaptor.getPolynomial().getType());
    auto numTerms = polyTensorType.getShape()[0];
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // The loop runs i = 1, ..., N. The scf.for upper bound is exclusive, so
    // it is N + 1. N itself is kept as a separate constant because the
    // coefficient index is computed as N - i; using the upper bound (N + 1)
    // there would read indices N, ..., 1, i.e. one past the end of the tensor
    // on the first iteration and never the constant term.
    auto lowerBound =
        b.create<arith::ConstantOp>(b.getIndexType(), b.getIndexAttr(1));
    auto numTermsOp = b.create<arith::ConstantOp>(b.getIndexType(),
                                                  b.getIndexAttr(numTerms));
    auto upperBound = b.create<arith::ConstantOp>(
        b.getIndexType(), b.getIndexAttr(numTerms + 1));
    // The step is also 1, so the lower-bound constant is reused.
    auto step = lowerBound;

    auto poly = adaptor.getPolynomial();
    auto point = adaptor.getPoint();

    // Horner's method:
    //
    //   accum = 0
    //   for i = 1, 2, ..., N
    //     accum = point * accum + coeff[N - i]
    auto accum =
        b.create<arith::ConstantOp>(b.getI32Type(), b.getI32IntegerAttr(0));
    auto loop = b.create<scf::ForOp>(
        lowerBound, upperBound, step, accum.getResult(),
        [&](OpBuilder &builder, Location loc, Value loopIndex,
            ValueRange loopState) {
          ImplicitLocOpBuilder b(op.getLoc(), builder);
          auto accum = loopState.front();
          // N - i walks the coefficients from index N-1 down to 0.
          auto coeffIndex = b.create<arith::SubIOp>(numTermsOp, loopIndex);
          auto mulOp = b.create<arith::MulIOp>(point, accum);
          auto result = b.create<arith::AddIOp>(
              mulOp, b.create<tensor::ExtractOp>(poly, coeffIndex.getResult()));
          b.create<scf::YieldOp>(result.getResult());
        });

    rewriter.replaceOp(op, loop.getResult(0));
    return success();
  }
};
182+
183+
/// Rewrites poly.from_tensor by forwarding the input coefficient tensor,
/// padded with trailing zeros when it is shorter than the converted result
/// tensor type.
struct ConvertFromTensor : public OpConversionPattern<FromTensorOp> {
  ConvertFromTensor(mlir::MLIRContext *context)
      : OpConversionPattern<FromTensorOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      FromTensorOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // Target tensor type, obtained by converting the op's polynomial result.
    auto targetTy = cast<RankedTensorType>(
        typeConverter->convertType(op->getResultTypes()[0]));
    auto targetLen = targetTy.getShape()[0];
    auto targetEltTy = targetTy.getElementType();

    auto sourceTy = op.getInput().getType();
    auto sourceLen = sourceTy.getShape()[0];

    // Nothing to pad: the input tensor stands in for the polynomial directly.
    if (!(sourceLen < targetLen)) {
      rewriter.replaceOp(op, adaptor.getInput());
      return success();
    }

    // Zero pad the tensor, since the coefficients' size is less than the
    // polynomial degree: no padding in front, (targetLen - sourceLen) behind.
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);
    SmallVector<OpFoldResult, 1> low, high;
    low.push_back(rewriter.getIndexAttr(0));
    high.push_back(rewriter.getIndexAttr(targetLen - sourceLen));
    auto zeroCoeff =
        b.create<arith::ConstantOp>(rewriter.getIntegerAttr(targetEltTy, 0));
    auto padded = b.create<tensor::PadOp>(targetTy, adaptor.getInput(), low,
                                          high, zeroCoeff,
                                          /*nofold=*/false);

    rewriter.replaceOp(op, padded.getResult());
    return success();
  }
};
218+
219+
/// Rewrites poly.to_tensor by replacing it with its type-converted input
/// operand; no new ops are created.
struct ConvertToTensor : public OpConversionPattern<ToTensorOp> {
  ConvertToTensor(mlir::MLIRContext *context)
      : OpConversionPattern<ToTensorOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ToTensorOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    // The adaptor's input is the operand after type conversion.
    auto convertedInput = adaptor.getInput();
    rewriter.replaceOp(op, convertedInput);
    return success();
  }
};
232+
233+
/// Rewrites poly.constant as an arith.constant holding the coefficients,
/// wrapped in a poly.from_tensor that produces the original result type.
struct ConvertConstant : public OpConversionPattern<ConstantOp> {
  ConvertConstant(mlir::MLIRContext *context)
      : OpConversionPattern<ConstantOp>(context) {}

  using OpConversionPattern::OpConversionPattern;

  LogicalResult matchAndRewrite(
      ConstantOp op, OpAdaptor adaptor,
      ConversionPatternRewriter &rewriter) const override {
    ImplicitLocOpBuilder b(op.getLoc(), rewriter);

    // Materialize the coefficients attribute as a standard constant.
    auto coeffConstant =
        b.create<arith::ConstantOp>(adaptor.getCoefficients());

    // Re-wrap it so the replacement value keeps the op's original result type.
    auto polyTy = op.getResult().getType();
    auto wrapped = b.create<FromTensorOp>(polyTy, coeffConstant);

    rewriter.replaceOp(op, wrapped.getResult());
    return success();
  }
};
250+
251+
/// Pass that lowers the poly dialect: it registers the poly conversion
/// patterns together with the func/return/call/branch type-conversion
/// patterns, and runs a partial dialect conversion over the root operation.
struct PolyToStandard : impl::PolyToStandardBase<PolyToStandard> {
  using PolyToStandardBase::PolyToStandardBase;

  void runOnOperation() override {
    MLIRContext *ctx = &getContext();
    auto *root = getOperation();

    PolyToStandardTypeConverter typeConverter(ctx);

    // --- Rewrite patterns -------------------------------------------------
    RewritePatternSet patterns(ctx);
    patterns.add<ConvertAdd, ConvertConstant, ConvertSub, ConvertEval,
                 ConvertMul, ConvertFromTensor, ConvertToTensor>(typeConverter,
                                                                 ctx);
    // Type-conversion patterns for function signatures, returns, calls and
    // branch-like ops.
    populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
        patterns, typeConverter);
    populateReturnOpTypeConversionPattern(patterns, typeConverter);
    populateCallOpTypeConversionPattern(patterns, typeConverter);
    populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);

    // --- Legality ---------------------------------------------------------
    ConversionTarget target(*ctx);
    target.addLegalDialect<arith::ArithDialect>();
    target.addIllegalDialect<PolyDialect>();

    // A function is legal once both its signature and its body are legal
    // according to the type converter.
    target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
      return typeConverter.isSignatureLegal(op.getFunctionType()) &&
             typeConverter.isLegal(&op.getBody());
    });
    target.addDynamicallyLegalOp<func::ReturnOp>(
        [&](func::ReturnOp op) { return typeConverter.isLegal(op); });
    target.addDynamicallyLegalOp<func::CallOp>(
        [&](func::CallOp op) { return typeConverter.isLegal(op); });

    // Any other op is legal if it is neither branch-like nor return-like, or
    // if the corresponding type-conversion legality check accepts it.
    target.markUnknownOpDynamicallyLegal([&](Operation *op) {
      return isNotBranchOpInterfaceOrReturnLikeOp(op) ||
             isLegalForBranchOpInterfaceTypeConversionPattern(op,
                                                              typeConverter) ||
             isLegalForReturnOpTypeConversionPattern(op, typeConverter);
    });

    if (succeeded(applyPartialConversion(root, target, std::move(patterns)))) {
      return;
    }
    signalPassFailure();
  }
};
296+
297+
} // namespace poly
298+
} // namespace tutorial
299+
} // namespace mlir

0 commit comments

Comments
 (0)