Skip to content

LLVM class exam solution - Vlastimil Dort #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ class Compiler : public Visitor {
case ast::BinExp::Type::gt:
result = RUNTIME_CALL(genericGt, lhs, rhs);
return;
case ast::BinExp::Type::dot:
result = RUNTIME_CALL(genericDot, lhs, rhs);
return;
default: // can't happen
return;
}
Expand Down
12 changes: 10 additions & 2 deletions runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,11 +526,19 @@ RVal * c(int size, ...) {
}

double doubleDot(DoubleVector * lhs, DoubleVector * rhs) {
assert(false and "Fill me in");
int productSize = max(lhs->size, rhs->size);
double result = 0;
for (int i = 0; i < productSize; ++i)
result += lhs->data[i % lhs->size] * rhs->data[i % rhs->size];
return result;
}

RVal * genericDot(RVal * lhs, RVal * rhs) {
assert(false and "Fill me in");
if (lhs->type != rhs->type)
throw "Incompatible types for dot product operator";
if (lhs->type != RVal::Type::Double)
throw "Invalid types for dot product operator";
return new RVal(new DoubleVector(doubleDot(lhs->d, rhs->d)));
}


Expand Down
10 changes: 5 additions & 5 deletions tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ namespace rift {
TEST("a = \"aba\" a[c(0,2)]", "aa");
TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3);

// project2();
// project3();
// project4();
// project5();
// project6();
project2();
project3();
project4();
project5();
project6();
}

} // namespace rift
7 changes: 6 additions & 1 deletion type_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ char TypeAnalysis::ID = 0;

AType * AType::top = createTop();


void TypeAnalysis::genericDot(CallInst * ci) {
// The result of the %*% operator is (if it succeeds) always a single scalar
state.update(ci, new AType(AType::Kind::R, AType::Kind::DV, AType::Kind::D));
}

void TypeAnalysis::genericArithmetic(CallInst * ci) {
AType * lhs = state.get(ci->getOperand(0));
Expand Down Expand Up @@ -100,6 +103,8 @@ bool TypeAnalysis::runOnFunction(llvm::Function & f) {
genericArithmetic(ci);
} else if (s == "genericDiv") {
genericArithmetic(ci);
} else if (s == "genericDot") {
genericDot(ci);
} else if (s == "genericEq") {
genericRelational(ci);
} else if (s == "genericNeq") {
Expand Down
1 change: 1 addition & 0 deletions type_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class TypeAnalysis : public llvm::FunctionPass {

private:
void genericArithmetic(llvm::CallInst * ci);
void genericDot(llvm::CallInst * ci);
void genericRelational(llvm::CallInst * ci);
void genericGetElement(llvm::CallInst * ci);

Expand Down
28 changes: 28 additions & 0 deletions unboxing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,32 @@ bool Unboxing::genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function
}
}

bool Unboxing::genericDot() {
AType * lhs = state().get(ins->getOperand(0));
AType * rhs = state().get(ins->getOperand(1));
if(lhs->isDouble() and rhs->isDouble()) {
// If both arguments are known to be vectors of double,
// we can unbox them. The result is then double scalar.
llvm::Value *l, *r, *result;
if(lhs->isScalar() and rhs->isScalar()) {
// If both arguments are scalar, use multiplication directly
l = getScalarPayload(lhs);
r = getScalarPayload(rhs);
result = BinaryOperator::Create(Instruction::FMul, l, r, "", ins);
} else {
// Otherwise use doubleDot function call
l = getVectorPayload(lhs);
r = getVectorPayload(rhs);
result = RUNTIME_CALL(m->doubleDot, l, r);
}
AType *result_t = updateAnalysis(result, new AType(AType::Kind::D));
ins->replaceAllUsesWith(box(result_t));
return true;
} else {
return false;
}
}

void Unboxing::doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop) {
assert(lhs->isDouble() and rhs->isDouble() and "Doubles expected");
AType * result_t;
Expand Down Expand Up @@ -337,6 +363,8 @@ bool Unboxing::runOnFunction(llvm::Function & f) {
erase = genericArithmetic(Instruction::FMul, m->doubleMul);
} else if (s == "genericDiv") {
erase = genericArithmetic(Instruction::FDiv, m->doubleDiv);
} else if (s == "genericDot") {
erase = genericDot();
} else if (s == "genericLt") {
erase = genericRelational(FCmpInst::FCMP_OLT, m->doubleLt);
} else if (s == "genericGt") {
Expand Down
2 changes: 2 additions & 0 deletions unboxing.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Unboxing : public llvm::FunctionPass {

bool genericAdd();

bool genericDot();

bool genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function * fop);

void doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop);
Expand Down