diff --git a/compiler.cpp b/compiler.cpp index 606a455..bcc27b2 100644 --- a/compiler.cpp +++ b/compiler.cpp @@ -390,6 +390,9 @@ class Compiler : public Visitor { case ast::BinExp::Type::gt: result = RUNTIME_CALL(genericGt, lhs, rhs); return; + case ast::BinExp::Type::dot: + result = RUNTIME_CALL(genericDot, lhs, rhs); + return; default: // can't happen return; } diff --git a/runtime.cpp b/runtime.cpp index c1bdd1f..59595d8 100644 --- a/runtime.cpp +++ b/runtime.cpp @@ -526,11 +526,19 @@ RVal * c(int size, ...) { } double doubleDot(DoubleVector * lhs, DoubleVector * rhs) { - assert(false and "Fill me in"); + int productSize = max(lhs->size, rhs->size); + double result = 0; + for (int i = 0; i < productSize; ++i) + result += lhs->data[i % lhs->size] * rhs->data[i % rhs->size]; + return result; } RVal * genericDot(RVal * lhs, RVal * rhs) { - assert(false and "Fill me in"); + if (lhs->type != rhs->type) + throw "Incompatible types for dot product operator"; + if (lhs->type != RVal::Type::Double) + throw "Invalid types for dot product operator"; + return new RVal(new DoubleVector(doubleDot(lhs->d, rhs->d))); } diff --git a/tests.cpp b/tests.cpp index fe05cac..f347ccb 100644 --- a/tests.cpp +++ b/tests.cpp @@ -222,11 +222,11 @@ namespace rift { TEST("a = \"aba\" a[c(0,2)]", "aa"); TEST("a = c(1,2,3) a[c(0,1)] = 56 a", 56, 56, 3); - // project2(); - // project3(); - // project4(); - // project5(); - // project6(); + project2(); + project3(); + project4(); + project5(); + project6(); } } // namespace rift diff --git a/type_analysis.cpp b/type_analysis.cpp index 136e3f7..2684189 100644 --- a/type_analysis.cpp +++ b/type_analysis.cpp @@ -14,7 +14,10 @@ char TypeAnalysis::ID = 0; AType * AType::top = createTop(); - +void TypeAnalysis::genericDot(CallInst * ci) { + // The result of the %*% operator is (if it succeeds) always a single scalar + state.update(ci, new AType(AType::Kind::R, AType::Kind::DV, AType::Kind::D)); +} void TypeAnalysis::genericArithmetic(CallInst * ci) { AType * lhs = state.get(ci->getOperand(0)); @@ -100,6 +103,8 @@ bool TypeAnalysis::runOnFunction(llvm::Function & f) { genericArithmetic(ci); } else if (s == "genericDiv") { genericArithmetic(ci); + } else if (s == "genericDot") { + genericDot(ci); } else if (s == "genericEq") { genericRelational(ci); } else if (s == "genericNeq") { diff --git a/type_analysis.h b/type_analysis.h index 79f3d7e..55cb23c 100644 --- a/type_analysis.h +++ b/type_analysis.h @@ -218,6 +218,7 @@ class TypeAnalysis : public llvm::FunctionPass { private: void genericArithmetic(llvm::CallInst * ci); + void genericDot(llvm::CallInst * ci); void genericRelational(llvm::CallInst * ci); void genericGetElement(llvm::CallInst * ci); diff --git a/unboxing.cpp b/unboxing.cpp index 2b11f32..35ca9be 100644 --- a/unboxing.cpp +++ b/unboxing.cpp @@ -169,6 +169,32 @@ bool Unboxing::genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function } } +bool Unboxing::genericDot() { + AType * lhs = state().get(ins->getOperand(0)); + AType * rhs = state().get(ins->getOperand(1)); + if(lhs->isDouble() and rhs->isDouble()) { + // If both arguments are known to be vectors of double, + // we can unbox them. The result is then double scalar. + llvm::Value *l, *r, *result; + if(lhs->isScalar() and rhs->isScalar()) { + // If both arguments are scalar, use multiplication directly + l = getScalarPayload(lhs); + r = getScalarPayload(rhs); + result = BinaryOperator::Create(Instruction::FMul, l, r, "", ins); + } else { + // Otherwise use doubleDot function call + l = getVectorPayload(lhs); + r = getVectorPayload(rhs); + result = RUNTIME_CALL(m->doubleDot, l, r); + } + AType *result_t = updateAnalysis(result, new AType(AType::Kind::D)); + ins->replaceAllUsesWith(box(result_t)); + return true; + } else { + return false; + } +} + void Unboxing::doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop) { assert(lhs->isDouble() and rhs->isDouble() and "Doubles expected"); AType * result_t; @@ -337,6 +363,8 @@ bool Unboxing::runOnFunction(llvm::Function & f) { erase = genericArithmetic(Instruction::FMul, m->doubleMul); } else if (s == "genericDiv") { erase = genericArithmetic(Instruction::FDiv, m->doubleDiv); + } else if (s == "genericDot") { + erase = genericDot(); } else if (s == "genericLt") { erase = genericRelational(FCmpInst::FCMP_OLT, m->doubleLt); } else if (s == "genericGt") { diff --git a/unboxing.h b/unboxing.h index bb134c1..cf70169 100644 --- a/unboxing.h +++ b/unboxing.h @@ -48,6 +48,8 @@ class Unboxing : public llvm::FunctionPass { bool genericAdd(); + bool genericDot(); + bool genericArithmetic(llvm::Instruction::BinaryOps op, llvm::Function * fop); void doubleRelational(AType * lhs, AType * rhs, llvm::CmpInst::Predicate op, llvm::Function * fop);