Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

[WIP] Add bitwise operators and propagate them from parser -> halide -> isl #380

Open
wants to merge 8 commits into
base: add-operators
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tc/core/halide2isl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ using namespace tc::polyhedral::detail;

SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {
// const Stmt& s) {
// Collect and categorize all the Variable symbols
// Collect and categorize all the Halide Variable symbols as reduction
// or index variables
class BuildSymbolTable : public IRVisitor {
using IRVisitor::visit;
std::set<std::string> included;
Expand All @@ -60,7 +61,7 @@ SymbolTable makeSymbolTable(const tc2halide::HalideComponents& components) {

components.stmt.accept(&builder);
// Get params from components.params which contain everything declared in
// tcdef. However, the 0-D tensors are registered as both params and inputs,
// TC Def. However, the 0-D tensors are registered as both params and inputs,
// filter those out.
for (auto kvp : components.params) {
bool skip = false;
Expand Down Expand Up @@ -200,7 +201,7 @@ std::vector<isl::aff> makeIslAffBoundsFromExpr(
std::vector<isl::aff> result;
// We cannot span multiple constraints if a modulo operation is involved.
// x > max(a,b) % C is not equivalent to (x > a % C && x > b % C).
auto lhs = makeIslAffBoundsFromExpr(space, e, false, false);
auto lhs = makeIslAffBoundsFromExpr(space, op->a, false, false);
CHECK_EQ(lhs.size(), 1u);
if (const int64_t* b = as_const_int(op->b)) {
return {lhs[0].mod(isl::val(space.get_ctx(), *b))};
Expand Down
6 changes: 6 additions & 0 deletions tc/core/libraries.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ template<typename T> inline __device__ T floord(T n, T d) {
return n < 0 ? - (-n + d - 1)/d : n / d;
}
#define if_then_else(cond,a,b) ((cond) ? (a) : (b))
#define shift_left(a,b) ((a) << (b))
#define shift_right(a,b) ((a) >> (b))
#define bitwise_and(a,b) ((a) & (b))
#define bitwise_xor(a,b) ((a) ^ (b))
#define bitwise_or(a,b) ((a) | (b))
#define bitwise_not(a) (~(a))
)C";
} // namespace cpp

Expand Down
107 changes: 65 additions & 42 deletions tc/core/tc2halide.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,43 +62,45 @@ Type translateScalarType(int tcType) {
}
}

// Translate the TC def input params to corresponding Halide components.
// params, inputs will be populated here.
void translateParam(
const lang::Param& p,
map<string, Parameter>* params,
vector<ImageParam>* inputs) {
// Check if the param has already been converted to halide components.
if (params->find(p.ident().name()) != params->end()) {
return;
} else {
lang::TensorType type = p.tensorType();
int dimensions = (int)type.dims().size();
ImageParam imageParam(
translateScalarType(type.scalarType()), dimensions, p.ident().name());
inputs->push_back(imageParam);
vector<Expr> dims;
for (auto d_ : type.dims()) {
if (d_->kind() == lang::TK_IDENT) {
auto d = lang::Ident(d_);
auto it = params->find(d.name());
Parameter p;
if (it != params->end()) {
p = it->second;
} else {
p = Parameter(Int(32), false, 0, d.name(), true);
(*params)[d.name()] = p;
}
dims.push_back(Variable::make(Int(32), p.name(), p));
}
lang::TensorType type = p.tensorType();
int dimensions = (int)type.dims().size();
ImageParam imageParam(
translateScalarType(type.scalarType()), dimensions, p.ident().name());
inputs->push_back(imageParam);
vector<Expr> dims;
for (auto d_ : type.dims()) {
if (d_->kind() == lang::TK_IDENT) {
auto d = lang::Ident(d_);
auto it = params->find(d.name());
Parameter p;
if (it != params->end()) {
p = it->second;
} else {
CHECK(d_->kind() == lang::TK_CONST);
int32_t value = lang::Const(d_).value();
dims.push_back(Expr(value));
p = Parameter(Int(32), false, 0, d.name(), true);
(*params)[d.name()] = p;
}
dims.push_back(Variable::make(Int(32), p.name(), p));
} else {
CHECK(d_->kind() == lang::TK_CONST);
int32_t value = lang::Const(d_).value();
dims.push_back(Expr(value));
}
}

for (int i = 0; i < imageParam.dimensions(); i++) {
imageParam.dim(i).set_bounds(0, dims[i]);
}
(*params)[imageParam.name()] = imageParam.parameter();
for (int i = 0; i < imageParam.dimensions(); i++) {
imageParam.dim(i).set_bounds(0, dims[i]);
}
(*params)[imageParam.name()] = imageParam.parameter();
}

void translateOutput(
Expand Down Expand Up @@ -156,6 +158,8 @@ Expr translateExpr(
return t(0) * t(1);
case '/':
return t(0) / t(1);
case '%':
return t(0) % t(1);
case lang::TK_MIN:
return min(t(0), t(1));
case lang::TK_MAX:
Expand Down Expand Up @@ -186,6 +190,18 @@ Expr translateExpr(
return t(0) && t(1);
case lang::TK_OR:
return t(0) || t(1);
case lang::TK_LS:
return t(0) << t(1);
case lang::TK_RS:
return t(0) >> t(1);
case '|':
return t(0) | t(1);
case '^':
return t(0) ^ t(1);
case '&':
return t(0) & t(1);
case '~':
return ~t(0);
case lang::TK_BUILT_IN: {
auto b = lang::BuiltIn(expr);
vector<Expr> exprs;
Expand Down Expand Up @@ -492,20 +508,22 @@ Expr reductionUpdate(Expr e) {
return Call::make(e.type(), kReductionUpdate, {e}, Call::Intrinsic);
}

// Translate a single TC comprehension/statement to Halide components: funcs,
// bounds, reductions.
void translateComprehension(
const lang::Comprehension& c,
const lang::Comprehension& comprehension,
const map<string, Parameter>& params,
bool throwWarnings,
map<string, Function>* funcs,
FunctionBounds* bounds,
vector<Function>* reductions) {
Function f;
auto it = funcs->find(c.ident().name());
auto it = funcs->find(comprehension.ident().name());
if (it != funcs->end()) {
f = it->second;
} else {
f = Function(c.ident().name());
(*funcs)[c.ident().name()] = f;
f = Function(comprehension.ident().name());
(*funcs)[comprehension.ident().name()] = f;
}
// Function is the internal Halide IR type for a pipeline
// stage. Func is the front-end class that wraps it. Here it's
Expand All @@ -514,7 +532,7 @@ void translateComprehension(

vector<Var> lhs;
vector<Expr> lhs_as_exprs;
for (lang::Ident id : c.indices()) {
for (lang::Ident id : comprehension.indices()) {
lhs.push_back(Var(id.name()));
lhs_as_exprs.push_back(lhs.back());
}
Expand All @@ -523,17 +541,17 @@ void translateComprehension(
// in the future we may consider using Halide Let bindings when they
// are supported later
map<string, Expr> lets;
for (auto wc : c.whereClauses()) {
for (auto wc : comprehension.whereClauses()) {
if (wc->kind() == lang::TK_LET) {
auto let = lang::Let(wc);
lets[let.name().name()] = translateExpr(let.rhs(), params, *funcs, lets);
}
}

Expr rhs = translateExpr(c.rhs(), params, *funcs, lets);
Expr rhs = translateExpr(comprehension.rhs(), params, *funcs, lets);

std::vector<Expr> all_exprs;
for (auto wc : c.whereClauses()) {
for (auto wc : comprehension.whereClauses()) {
if (wc->kind() == lang::TK_EXISTS) {
all_exprs.push_back(
translateExpr(lang::Exists(wc).exp(), params, *funcs, lets));
Expand All @@ -557,7 +575,7 @@ void translateComprehension(
// values (2) +=!, TK_PLUS_EQ_B which first sets the tensor to the identity
// for the reduction and then applies the reduction.
bool should_zero = false;
switch (c.assignment()->kind()) {
switch (comprehension.assignment()->kind()) {
case lang::TK_PLUS_EQ_B:
should_zero = true; // fallthrough
case lang::TK_PLUS_EQ:
Expand Down Expand Up @@ -589,11 +607,12 @@ void translateComprehension(
case '=':
break;
default:
throw lang::ErrorReport(c) << "Unimplemented reduction "
<< c.assignment()->range().text() << "\n";
throw lang::ErrorReport(comprehension)
<< "Unimplemented reduction "
<< comprehension.assignment()->range().text() << "\n";
}

if (c.assignment()->kind() != '=') {
if (comprehension.assignment()->kind() != '=') {
reductions->push_back(f);
}

Expand Down Expand Up @@ -633,7 +652,7 @@ void translateComprehension(
Scope<Interval> solution;

// Put anything explicitly specified with a 'where' class in the solution
for (auto constraint_ : c.whereClauses()) {
for (auto constraint_ : comprehension.whereClauses()) {
if (constraint_->kind() != lang::TK_RANGE_CONSTRAINT)
continue;
auto constraint = lang::RangeConstraint(constraint_);
Expand All @@ -654,7 +673,8 @@ void translateComprehension(

// Infer the rest
all_exprs.push_back(rhs);
forwardBoundsInference(all_exprs, *bounds, c, throwWarnings, &solution);
forwardBoundsInference(
all_exprs, *bounds, comprehension, throwWarnings, &solution);

// TODO: What if subsequent updates have incompatible bounds
// (e.g. an in-place stencil)?. The .bound directive will use the
Expand All @@ -665,7 +685,7 @@ void translateComprehension(

for (Var v : lhs) {
if (!solution.contains(v.name())) {
throw lang::ErrorReport(c)
throw lang::ErrorReport(comprehension)
<< "Free variable " << v
<< " was not solved in range inference. May not be used right-hand side";
}
Expand All @@ -689,7 +709,7 @@ void translateComprehension(
for (size_t i = 0; i < unbound.size(); i++) {
auto v = unbound[unbound.size() - 1 - i];
if (!solution.contains(v->name)) {
throw lang::ErrorReport(c)
throw lang::ErrorReport(comprehension)
<< "Free variable " << v << " is unconstrained. "
<< "Use a 'where' clause to set its range.";
}
Expand Down Expand Up @@ -737,6 +757,7 @@ void translateComprehension(
stage.reorder(loop_nest);
}

// Translate a semantically checked TC def to HalideComponents struct.
HalideComponents translateDef(const lang::Def& def, bool throwWarnings) {
map<string, Function> funcs;
HalideComponents components;
Expand Down Expand Up @@ -956,6 +977,8 @@ translate(isl::ctx ctx, const lang::TreeRef& treeRef, bool throwWarnings) {
lang::Def(lang::Sema().checkFunction(treeRef)), throwWarnings);
}

// NOTE: there is no guarantee here that the tc string has only one def. It
// could have many defs. Only first def will be converted in that case.
HalideComponents
translate(isl::ctx ctx, const std::string& tc, bool throwWarnings) {
LOG_IF(INFO, tc::FLAGS_debug_halide) << tc;
Expand Down
4 changes: 2 additions & 2 deletions tc/core/tc2halide.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace tc2halide {
// of the input and output tensors. We do not explicitly enumerate the
// scalar params.
struct HalideComponents {
lang::TreeRef
def; // post-semantic analaysis tree, used for later error reporting
// post-semantic analaysis tree, used for later error reporting
lang::TreeRef def;
Halide::Internal::Stmt stmt;
std::vector<Halide::ImageParam> inputs;
std::map<std::string, Halide::Internal::Parameter> params;
Expand Down
12 changes: 9 additions & 3 deletions tc/lang/lexer.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,12 @@ namespace lang {
_(TK_NE, "neq", "!=") \
_(TK_AND, "and", "&&") \
_(TK_OR, "or", "||") \
_(TK_LS, "ls", "<<") \
_(TK_RS, "rs", ">>") \
_(TK_LET, "let", "") \
_(TK_EXISTS, "exists", "exists")

static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!";
static const char* valid_single_char_tokens = "+-*/()[]?:,={}><!%&^|~";

enum TokenKind {
// we use characters to represent themselves so skip all valid characters
Expand Down Expand Up @@ -135,12 +137,16 @@ struct SharedParserData {
{'?'},
{TK_OR},
{TK_AND},
{'|'},
{'^'},
{'&'},
{'>', '<', TK_LE, TK_GE, TK_EQ, TK_NE},
{TK_LS, TK_RS},
{'+', '-'},
{'*', '/'},
{'*', '/', '%'},
};
std::vector<std::vector<int>> unary_ops = {
{'-', '!'},
{'-', '!', '~'},
};

std::stringstream ss;
Expand Down
Loading