diff --git a/compiler.cpp b/compiler.cpp index 606a455..bed70fd 100644 --- a/compiler.cpp +++ b/compiler.cpp @@ -390,6 +390,9 @@ class Compiler : public Visitor { case ast::BinExp::Type::gt: result = RUNTIME_CALL(genericGt, lhs, rhs); return; + case ast::BinExp::Type::dot: + result = RUNTIME_CALL(genericDot, lhs, rhs); + return; default: // can't happen return; } diff --git a/runtime.cpp b/runtime.cpp index c1bdd1f..0707f5e 100644 --- a/runtime.cpp +++ b/runtime.cpp @@ -526,11 +526,22 @@ RVal * c(int size, ...) { } double doubleDot(DoubleVector * lhs, DoubleVector * rhs) { - assert(false and "Fill me in"); + int maxLength = std::max(lhs->size, rhs->size); + double resolution = 0; + + for(int i = 0; i < maxLength; i++) + { + resolution += lhs->data[i % lhs->size] * rhs->data[i % rhs->size]; + } + + return resolution; } RVal * genericDot(RVal * lhs, RVal * rhs) { - assert(false and "Fill me in"); + if(lhs->type != RVal::Type::Double || rhs->type != RVal::Type::Double) + throw "Operands must be vectors of double"; + + return new RVal(new DoubleVector(doubleDot(lhs->d, rhs->d))); } diff --git a/tests.cpp b/tests.cpp index fe05cac..febf77c 100644 --- a/tests.cpp +++ b/tests.cpp @@ -154,7 +154,7 @@ namespace rift { void tests() { cout << "Running tests..." << endl; - TEST("1", 1); + /* TEST("1", 1); TEST("1 + 2", 3); TEST("1 - 2", -1); TEST("2 * 3", 6); @@ -220,13 +220,17 @@ namespace rift { TEST("a = c(1, 2, 3) a[1]", 2); TEST("a = \"aba\" a[c(0,2)]", "aa"); - TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3); - - // project2(); - // project3(); - // project4(); - // project5(); - // project6(); + TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3);*/ + cout << "Running test for task 2" << endl; + project2(); + cout << "Running test for task 3" << endl; + project3(); + cout << "Running test for task 4" << endl; + project4(); + cout << "Running test for task 5" << endl; + project5(); + cout << "Running test for task 6" << endl; + project6(); } } // namespace rift diff --git a/type_analysis.cpp b/type_analysis.cpp index 136e3f7..9342316 100644 --- a/type_analysis.cpp +++ b/type_analysis.cpp @@ -129,7 +129,9 @@ bool TypeAnalysis::runOnFunction(llvm::Function & f) { state.update(ci, new AType(AType::Kind::R)); } else if (s == "envGet") { state.update(ci, new AType(AType::Kind::R)); - } + } else if (s == "genericDot"){ + state.update(ci, new AType(AType::Kind::R, new AType(AType::Kind::DV, new AType(AType::Kind::D)))); + } } else if (PHINode * phi = dyn_cast(&i)) { AType * first = state.get(phi->getOperand(0)); AType * second = state.get(phi->getOperand(1)); diff --git a/unboxing.cpp b/unboxing.cpp index 2b11f32..6f2eaac 100644 --- a/unboxing.cpp +++ b/unboxing.cpp @@ -318,6 +318,27 @@ bool Unboxing::genericEval() { } } +bool Unboxing::genericDot() { + AType *lhs = state().get(ins->getOperand(0)); + AType *rhs = state().get(ins->getOperand(1)); + AType *tmp_result; + + // both must be doubles + if( rhs->isDouble() and lhs->isDouble()) { + if(lhs->isScalar() and rhs->isScalar()) { + tmp_result = updateAnalysis( + BinaryOperator::Create(Instruction::FMul, getScalarPayload(lhs), getScalarPayload(rhs), "", ins), new AType(AType::Kind::D)); + } else { + tmp_result = updateAnalysis( RUNTIME_CALL(m->doubleDot, getVectorPayload(lhs), getVectorPayload(rhs)), new AType(AType::Kind::D)); + } + + ins->replaceAllUsesWith(box(tmp_result)); + return true; + } else { + return false; + } +} + bool Unboxing::runOnFunction(llvm::Function & f) { //std::cout << "running Unboxing optimization..." << std::endl; m = reinterpret_cast(f.getParent()); @@ -353,7 +374,9 @@ bool Unboxing::runOnFunction(llvm::Function & f) { erase = genericC(); } else if (s == "genericEval") { erase = genericEval(); - } + } else if (s == "genericDot") { + erase = genericDot(); + } } if (erase) { llvm::Instruction * v = i; diff --git a/unboxing.h b/unboxing.h index bb134c1..e3fc1ea 100644 --- a/unboxing.h +++ b/unboxing.h @@ -68,6 +68,8 @@ class Unboxing : public llvm::FunctionPass { bool genericEval(); + bool genericDot(); + /** Rift module currently being optimized, obtained from the function. The module is used for the declarations of the runtime functions.