diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3cca2b7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.swp +build/ diff --git a/compiler.cpp b/compiler.cpp index 606a455..8ae0cc7 100644 --- a/compiler.cpp +++ b/compiler.cpp @@ -390,6 +390,9 @@ class Compiler : public Visitor { case ast::BinExp::Type::gt: result = RUNTIME_CALL(genericGt, lhs, rhs); return; + case ast::BinExp::Type::dot: + result = RUNTIME_CALL(genericDot, lhs, rhs); + return; default: // can't happen return; } diff --git a/runtime.cpp b/runtime.cpp index c1bdd1f..3d51c14 100644 --- a/runtime.cpp +++ b/runtime.cpp @@ -526,11 +526,20 @@ RVal * c(int size, ...) { } double doubleDot(DoubleVector * lhs, DoubleVector * rhs) { - assert(false and "Fill me in"); + int mLen = std::max(lhs->size, rhs->size); + double res = 0; + + for(int i = 0; i < mLen; i++) + { + res += lhs->data[i % lhs->size] * rhs->data[i % rhs->size]; + } + return res; } RVal * genericDot(RVal * lhs, RVal * rhs) { - assert(false and "Fill me in"); + if(lhs->type != RVal::Type::Double || rhs->type != RVal::Type::Double) + throw "Both operands to dot product must be double vectors"; + return new RVal(new DoubleVector(doubleDot(lhs->d, rhs->d))); } diff --git a/tests.cpp b/tests.cpp index fe05cac..3105f55 100644 --- a/tests.cpp +++ b/tests.cpp @@ -86,7 +86,8 @@ namespace rift { TEST("(f = function(a, b) { a %*% b })(c(1,2,3), c(3,2,1))", 10); TEST("(f = function(a, b) { a %*% b })(c(1,2,3), 3)", 18); TEST("(f = function(a, b) { a %*% b })(10, 2)", 20); - TEST("(f = function(a, b) { a %*% b })(c(1,2,3,4), c(5,6))", 5 + 6 * 2 + 3 * 5 + 4 * 6); + TEST("(f = function(a, b) { a %*% b })(c(1,2,3,4), c(5,6))", 1*5 + 2*6 + 3*5 + 4*6); + TEST("(f = function(a, b) { a %*% b })(c(1,2), c(5))", 1*5 + 2*5); } /** Checks that a dot operator result type is correctly set to be double scalar by the type analysis. */ @@ -222,11 +223,11 @@ namespace rift { TEST("a = \"aba\" a[c(0,2)]", "aa"); TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3); - // project2(); - // project3(); - // project4(); - // project5(); - // project6(); + project2(); + project3(); + project4(); + project5(); + project6(); } } // namespace rift diff --git a/type_analysis.cpp b/type_analysis.cpp index 136e3f7..3821e94 100644 --- a/type_analysis.cpp +++ b/type_analysis.cpp @@ -108,6 +108,9 @@ bool TypeAnalysis::runOnFunction(llvm::Function & f) { genericRelational(ci); } else if (s == "genericGt") { genericRelational(ci); + } else if (s == "genericDot") { + // DV with len 1 is R->DV->D? + state.update(ci, new AType(AType::Kind::R, new AType(AType::Kind::DV, new AType(AType::Kind::D)))); } else if (s == "length") { // result of length operation is always // double scalar diff --git a/unboxing.cpp b/unboxing.cpp index 2b11f32..92d9a17 100644 --- a/unboxing.cpp +++ b/unboxing.cpp @@ -169,6 +169,30 @@ bool Unboxing::genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function } } +bool Unboxing::genericDot() { + AType *lhs = state().get(ins->getOperand(0)); + AType *rhs = state().get(ins->getOperand(1)); + AType *result_t; + + // both must be doubles + if(not lhs->isDouble() or not rhs->isDouble()) { + return false; + } + + if(lhs->isScalar() and rhs->isScalar()) { + result_t = updateAnalysis( + BinaryOperator::Create(Instruction::FMul, getScalarPayload(lhs), getScalarPayload(rhs), "", ins), + new AType(AType::Kind::D)); + } else { + result_t = updateAnalysis( + RUNTIME_CALL(m->doubleDot, getVectorPayload(lhs), getVectorPayload(rhs)), + new AType(AType::Kind::D)); + } + + ins->replaceAllUsesWith(box(result_t)); + return true; +} + void Unboxing::doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop) { assert(lhs->isDouble() and rhs->isDouble() and "Doubles expected"); AType * result_t; @@ -337,6 +361,8 @@ bool Unboxing::runOnFunction(llvm::Function & f) { erase = genericArithmetic(Instruction::FMul, m->doubleMul); } else if (s == "genericDiv") { erase = genericArithmetic(Instruction::FDiv, m->doubleDiv); + } else if (s == "genericDot") { + erase = genericDot(); } else if (s == "genericLt") { erase = genericRelational(FCmpInst::FCMP_OLT, m->doubleLt); } else if (s == "genericGt") { diff --git a/unboxing.h b/unboxing.h index bb134c1..16ed2bf 100644 --- a/unboxing.h +++ b/unboxing.h @@ -46,6 +46,8 @@ class Unboxing : public llvm::FunctionPass { void doubleArithmetic(AType * lhs, AType * rhs, llvm::Instruction::BinaryOps op, llvm::Function * fop); + bool genericDot(); + bool genericAdd(); bool genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function * fop);