Skip to content

My solution. #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ class Compiler : public Visitor {
case ast::BinExp::Type::div:
result = RUNTIME_CALL(genericDiv, lhs, rhs);
return;
case ast::BinExp::Type::dot:
result = RUNTIME_CALL(genericDot, lhs, rhs);
return;
case ast::BinExp::Type::eq:
result = RUNTIME_CALL(genericEq, lhs, rhs);
return;
Expand Down
24 changes: 22 additions & 2 deletions runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -526,11 +526,31 @@ RVal * c(int size, ...) {
}

double doubleDot(DoubleVector * lhs, DoubleVector * rhs) {
assert(false and "Fill me in");
double result = 0;
int vectorSize = max(lhs->size, rhs->size);
if (lhs->size == 1 || rhs->size == 1) {
double scalar;
double *doubleVector;
if (lhs->size == 1) {
scalar = lhs->data[0];
doubleVector = rhs->data;
} else {
scalar = rhs->data[0];
doubleVector = lhs->data;
}
for (int i = 0; i < vectorSize; ++i)
result += scalar * doubleVector[i % vectorSize];
} else {
for (int i = 0; i < vectorSize; ++i)
result += lhs->data[i % lhs->size] * rhs->data[i % rhs->size];
}
return result;
}

RVal * genericDot(RVal * lhs, RVal * rhs) {
assert(false and "Fill me in");
if (lhs->type != RVal::Type::Double || rhs->type != RVal::Type::Double)
throw "Both operands of dot product have to be double vectors";
return new RVal(new DoubleVector(doubleDot(lhs->d, rhs->d)));
}


Expand Down
10 changes: 5 additions & 5 deletions tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ namespace rift {
TEST("a = \"aba\" a[c(0,2)]", "aa");
TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3);

// project2();
// project3();
// project4();
// project5();
// project6();
project2();
project3();
project4();
project5();
project6();
}

} // namespace rift
6 changes: 6 additions & 0 deletions type_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ void TypeAnalysis::genericArithmetic(CallInst * ci) {
state.update(ci, lhs->merge(rhs));
}

void TypeAnalysis::genericDot(CallInst * ci) {
state.update(ci, new AType(AType::Kind::R, AType::Kind::DV, AType::Kind::D));
}

void TypeAnalysis::genericRelational(CallInst * ci) {
AType * lhs = state.get(ci->getOperand(0));
AType * rhs = state.get(ci->getOperand(1));
Expand Down Expand Up @@ -100,6 +104,8 @@ bool TypeAnalysis::runOnFunction(llvm::Function & f) {
genericArithmetic(ci);
} else if (s == "genericDiv") {
genericArithmetic(ci);
} else if (s == "genericDot") {
genericDot(ci);
} else if (s == "genericEq") {
genericRelational(ci);
} else if (s == "genericNeq") {
Expand Down
1 change: 1 addition & 0 deletions type_analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ class TypeAnalysis : public llvm::FunctionPass {

private:
void genericArithmetic(llvm::CallInst * ci);
void genericDot(llvm::CallInst * ci);
void genericRelational(llvm::CallInst * ci);
void genericGetElement(llvm::CallInst * ci);

Expand Down
21 changes: 21 additions & 0 deletions unboxing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,25 @@ bool Unboxing::genericAdd() {
}
}

bool Unboxing::genericDot() {
AType * lhs = state().get(ins->getOperand(0));
AType * rhs = state().get(ins->getOperand(1));
if (!lhs->isDouble() || !rhs->isDouble())
return false;
AType * result_t;
if (lhs->isScalar() and rhs->isScalar()) {
result_t = updateAnalysis(
BinaryOperator::Create(Instruction::FMul, getScalarPayload(lhs), getScalarPayload(rhs), "", ins),
new AType(AType::Kind::D));
} else {
result_t = updateAnalysis(
RUNTIME_CALL(m->doubleDot, getVectorPayload(lhs), getVectorPayload(rhs)),
new AType(AType::Kind::D));
}
ins->replaceAllUsesWith(box(result_t));
return true;
}

bool Unboxing::genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function * fop) {
AType * lhs = state().get(ins->getOperand(0));
AType * rhs = state().get(ins->getOperand(1));
Expand Down Expand Up @@ -337,6 +356,8 @@ bool Unboxing::runOnFunction(llvm::Function & f) {
erase = genericArithmetic(Instruction::FMul, m->doubleMul);
} else if (s == "genericDiv") {
erase = genericArithmetic(Instruction::FDiv, m->doubleDiv);
} else if (s == "genericDot") {
erase = genericDot();
} else if (s == "genericLt") {
erase = genericRelational(FCmpInst::FCMP_OLT, m->doubleLt);
} else if (s == "genericGt") {
Expand Down
2 changes: 2 additions & 0 deletions unboxing.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class Unboxing : public llvm::FunctionPass {

bool genericAdd();

bool genericDot();

bool genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function * fop);

void doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop);
Expand Down