Skip to content
73 changes: 17 additions & 56 deletions src/BoundsInference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,52 +189,6 @@ bool is_fused_with_others(const vector<vector<Function>> &fused_groups,
return false;
}

// An inliner that can inline an entire set of functions at once. The inliner in
// Inline.h only handles with one function at a time.
class Inliner : public IRMutator {
public:
std::set<Function, Function::Compare> to_inline;

Expr do_inlining(const Expr &e) {
return common_subexpression_elimination(mutate(e));
}

protected:
std::map<Function, std::map<int, Expr>, Function::Compare> qualified_bodies;

Expr get_qualified_body(const Function &f, int idx) {
auto it = qualified_bodies.find(f);
if (it != qualified_bodies.end()) {
auto it2 = it->second.find(idx);
if (it2 != it->second.end()) {
return it2->second;
}
}
Expr e = qualify(f.name() + ".", f.values()[idx]);
e = do_inlining(e);
qualified_bodies[f][idx] = e;
return e;
}

Expr visit(const Call *op) override {
if (op->func.defined()) {
Function f(op->func);
if (to_inline.count(f)) {
auto args = mutate(op->args);
Expr body = get_qualified_body(f, op->value_index);
const vector<string> &func_args = f.args();
for (size_t i = 0; i < args.size(); i++) {
body = Let::make(f.name() + "." + func_args[i], args[i], body);
}
return body;
}
}
return IRMutator::visit(op);
}

using IRMutator::visit;
};

class BoundsInference : public IRMutator {
public:
const vector<Function> &funcs;
Expand Down Expand Up @@ -686,7 +640,7 @@ class BoundsInference : public IRMutator {
vector<pair<Expr, int>> buffers_to_annotate;
for (const auto &arg : args) {
if (arg.is_expr()) {
bounds_inference_args.push_back(inliner->do_inlining(arg.expr));
bounds_inference_args.push_back((*inliner)(arg.expr));
} else if (arg.is_func()) {
Function input(arg.func);
for (int k = 0; k < input.outputs(); k++) {
Expand Down Expand Up @@ -849,16 +803,23 @@ class BoundsInference : public IRMutator {
// Compute the intrinsic relationships between the stages of
// the functions.

// Figure out which functions will be inlined away
// Figure out which functions will be inlined away.
vector<bool> inlined(f.size());
for (size_t i = 0; i < inlined.size(); i++) {
if (i < f.size() - 1 &&
f[i].schedule().compute_level().is_inlined() &&
f[i].can_be_inlined()) {
inlined[i] = true;
inliner.to_inline.insert(f[i]);
} else {
inlined[i] = false;
inlined[i] = (i < f.size() - 1 &&
f[i].schedule().compute_level().is_inlined() &&
f[i].can_be_inlined());
}
// Register them with the Inliner in consumer-first order. f is in
// realization (producer-first) order, so we iterate backwards: the
// outermost consumer of each chain is added first, the bottom
// producer last. The Inliner's iterative-deepening loop processes
// entries in add() order, so consumers go first -- their materialized
// bodies expose Calls to producers, which the later (deeper) passes
// then substitute. See Inliner's class doc for the full picture.
for (size_t i = inlined.size(); i > 0; i--) {
if (inlined[i - 1]) {
inliner.add(f[i - 1]);
}
}

Expand Down Expand Up @@ -893,7 +854,7 @@ class BoundsInference : public IRMutator {
for (auto &s : stages) {
for (auto &cond_val : s.exprs) {
internal_assert(cond_val.value.defined());
cond_val.value = inliner.do_inlining(cond_val.value);
cond_val.value = inliner(cond_val.value);
}
}

Expand Down
Loading
Loading