Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/marco/Codegen/Lowering/BaseModelica/ModelLowerer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ class ModelLowerer : public Lowerer {
llvm::SmallVectorImpl<mlir::bmodelica::VariableOp> &components,
const ast::bmodelica::ClassModification &classModification);

[[nodiscard]] bool
lowerExperimentAnnotation(mlir::bmodelica::ModelOp modelOp,
const ast::bmodelica::Model &model);

protected:
using Lowerer::declare;
using Lowerer::declareVariables;
Expand Down
7 changes: 7 additions & 0 deletions include/marco/Dialect/BaseModelica/IR/BaseModelicaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3460,6 +3460,13 @@ def BaseModelica_ModelOp : BaseModelica_Op<"model",
let hasCanonicalizer = 1;

let extraClassDeclaration = [{
static constexpr llvm::StringLiteral getExperimentStartTimeAttrName() {
return llvm::StringLiteral("experiment.startTime");
}
static constexpr llvm::StringLiteral getExperimentEndTimeAttrName() {
return llvm::StringLiteral("experiment.endTime");
}

static void getCleaningPatterns(
mlir::RewritePatternSet& patterns,
mlir::MLIRContext* context);
Expand Down
14 changes: 14 additions & 0 deletions include/marco/Dialect/Runtime/IR/RuntimeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,20 @@ def Runtime_YieldOp : Runtime_Op<"yield",
];
}

def Runtime_StartTimeOp : Runtime_Op<"start_time",
[HasParent<"mlir::ModuleOp">]>
{
let arguments = (ins OptionalAttr<F64Attr>:$start_time);
let assemblyFormat = "($start_time^)? attr-dict";
}

def Runtime_EndTimeOp : Runtime_Op<"end_time",
[HasParent<"mlir::ModuleOp">]>
{
let arguments = (ins OptionalAttr<F64Attr>:$end_time);
let assemblyFormat = "($end_time^)? attr-dict";
}

def Runtime_ICModelBeginOp : Runtime_Op<"ic_model_begin",
[HasParent<"mlir::ModuleOp">,
SingleBlock,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "marco/Dialect/Runtime/IR/Runtime.h"
#include "marco/VariableFilter/VariableFilter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"

namespace mlir {
#define GEN_PASS_DEF_BASEMODELICATORUNTIMECONVERSIONPASS
Expand Down Expand Up @@ -193,11 +194,17 @@ BaseModelicaToRuntimeConversionPass::addMissingRuntimeFunctions(
mlir::OpBuilder &builder, mlir::ModuleOp moduleOp) {
mlir::OpBuilder::InsertionGuard guard(builder);

size_t numOfICModelBeginOps = 0, numOfICModelEndOps = 0;
size_t numOfDynamicModelBeginOps = 0, numOfDynamicModelEndOps = 0;
size_t numOfStartTimeOp = 0, numOfEndTimeOp = 0, numOfICModelBeginOps = 0,
numOfICModelEndOps = 0, numOfDynamicModelBeginOps = 0,
numOfDynamicModelEndOps = 0;

std::optional<mlir::FloatAttr> startTime, endTime;
for (auto &op : moduleOp.getOps()) {
if (mlir::isa<mlir::runtime::ICModelBeginOp>(op)) {
if (mlir::isa<mlir::runtime::StartTimeOp>(op)) {
++numOfStartTimeOp;
} else if (mlir::isa<mlir::runtime::EndTimeOp>(op)) {
++numOfEndTimeOp;
} else if (mlir::isa<mlir::runtime::ICModelBeginOp>(op)) {
++numOfICModelBeginOps;
} else if (mlir::isa<mlir::runtime::ICModelEndOp>(op)) {
++numOfICModelEndOps;
Expand All @@ -206,6 +213,30 @@ BaseModelicaToRuntimeConversionPass::addMissingRuntimeFunctions(
} else if (mlir::isa<mlir::runtime::DynamicModelEndOp>(op)) {
++numOfDynamicModelEndOps;
}

// Retrive potential information from the model.
if (auto modelOp = mlir::dyn_cast<ModelOp>(op)) {
if (auto startTimeInner = modelOp->getAttrOfType<mlir::FloatAttr>(
modelOp.getExperimentStartTimeAttrName())) {
startTime = startTimeInner;
}
if (auto endTimeInner = modelOp->getAttrOfType<mlir::FloatAttr>(
modelOp.getExperimentEndTimeAttrName())) {
endTime = endTimeInner;
}
}
}

if (numOfStartTimeOp == 0) {
builder.setInsertionPointToEnd(moduleOp.getBody());
builder.create<mlir::runtime::StartTimeOp>(
moduleOp->getLoc(), startTime ? *startTime : nullptr);
}

if (numOfEndTimeOp == 0) {
builder.setInsertionPointToEnd(moduleOp.getBody());
builder.create<mlir::runtime::EndTimeOp>(moduleOp->getLoc(),
endTime ? *endTime : nullptr);
}

if (numOfICModelBeginOps == 0) {
Expand Down
93 changes: 86 additions & 7 deletions lib/Codegen/Conversion/RuntimeToFunc/RuntimeToFunc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/DialectConversion.h"
#include <functional>

namespace mlir {
#define GEN_PASS_DEF_RUNTIMETOFUNCCONVERSIONPASS
Expand Down Expand Up @@ -53,6 +54,22 @@ class RuntimeOpRewritePattern : public mlir::OpRewritePattern<Op> {
loc, mlir::LLVM::LLVMPointerType::get(builder.getContext()), type,
globalPtr, llvm::ArrayRef<mlir::LLVM::GEPArg>{0, 0});
}

void createConstantFunc(
mlir::OpBuilder &builder, mlir::Location loc, llvm::StringRef name,
mlir::Type returnType,
std::function<mlir::Value(mlir::OpBuilder &, mlir::Location)>
valueConstructor) const {
mlir::OpBuilder::InsertionGuard guard(builder);

auto funcType = builder.getFunctionType({}, returnType);
auto funcOp = builder.create<mlir::func::FuncOp>(loc, name, funcType);
mlir::Block *entryBlock = funcOp.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);

auto value = valueConstructor(builder, loc);
builder.create<mlir::func::ReturnOp>(loc, value);
}
};
} // namespace

Expand Down Expand Up @@ -169,6 +186,65 @@ class DeinitFunctionOpLowering
}
};

class StartTimeOpLowering : public RuntimeOpRewritePattern<StartTimeOp> {
public:
using RuntimeOpRewritePattern<StartTimeOp>::RuntimeOpRewritePattern;

mlir::LogicalResult
matchAndRewrite(StartTimeOp op,
mlir::PatternRewriter &rewriter) const override {

const auto isPresent = op.getStartTime().has_value();
createConstantFunc(
rewriter, op->getLoc(), "hasExperimentStartTime", rewriter.getI1Type(),
[&isPresent](auto &builder, auto loc) {
return builder.template create<mlir::arith::ConstantOp>(
loc, builder.getBoolAttr(isPresent));
});

const auto startTime =
isPresent ? op.getStartTime().value().convertToDouble() : 0;
createConstantFunc(
rewriter, op.getLoc(), "getExperimentStartTime", rewriter.getF64Type(),
[&startTime](auto &builder, auto loc) {
return builder.template create<mlir::arith::ConstantOp>(
loc, builder.getF64FloatAttr(startTime));
});

rewriter.eraseOp(op);
return mlir::success();
}
};

class EndTimeOpLowering : public RuntimeOpRewritePattern<EndTimeOp> {
public:
using RuntimeOpRewritePattern<EndTimeOp>::RuntimeOpRewritePattern;

mlir::LogicalResult
matchAndRewrite(EndTimeOp op,
mlir::PatternRewriter &rewriter) const override {
const auto isPresent = op.getEndTime().has_value();
createConstantFunc(
rewriter, op->getLoc(), "hasExperimentEndTime", rewriter.getI1Type(),
[&isPresent](auto &builder, auto loc) {
return builder.template create<mlir::arith::ConstantOp>(
loc, builder.getBoolAttr(isPresent));
});

const auto endTime =
isPresent ? op.getEndTime().value().convertToDouble() : 0;
createConstantFunc(
rewriter, op->getLoc(), "getExperimentEndTime", rewriter.getF64Type(),
[&endTime](auto &builder, auto loc) {
return builder.template create<mlir::arith::ConstantOp>(
loc, builder.getF64FloatAttr(endTime));
});

rewriter.eraseOp(op);
return mlir::success();
}
};

class ICModelBeginOpLowering : public RuntimeOpRewritePattern<ICModelBeginOp> {
public:
using RuntimeOpRewritePattern<ICModelBeginOp>::RuntimeOpRewritePattern;
Expand Down Expand Up @@ -475,8 +551,9 @@ mlir::LogicalResult RuntimeToFuncConversionPass::convertOps() {
mlir::ConversionTarget target(getContext());

target.addIllegalOp<VariableGetterOp, InitFunctionOp, DeinitFunctionOp,
ICModelBeginOp, ICModelEndOp, DynamicModelBeginOp,
DynamicModelEndOp, EquationFunctionOp, ReturnOp>();
StartTimeOp, EndTimeOp, ICModelBeginOp, ICModelEndOp,
DynamicModelBeginOp, DynamicModelEndOp,
EquationFunctionOp, ReturnOp>();

target.addDynamicallyLegalOp<FunctionOp>(
[](FunctionOp op) { return op.isDeclaration(); });
Expand All @@ -486,11 +563,13 @@ mlir::LogicalResult RuntimeToFuncConversionPass::convertOps() {

mlir::RewritePatternSet patterns(&getContext());

patterns.insert<VariableGetterOpLowering, InitFunctionOpLowering,
DeinitFunctionOpLowering, ICModelBeginOpLowering,
ICModelEndOpLowering, DynamicModelBeginOpLowering,
DynamicModelEndOpLowering, EquationFunctionOpLowering,
FunctionOpLowering, ReturnOpLowering>(&getContext());
patterns
.insert<VariableGetterOpLowering, InitFunctionOpLowering,
DeinitFunctionOpLowering, StartTimeOpLowering, EndTimeOpLowering,
ICModelBeginOpLowering, ICModelEndOpLowering,
DynamicModelBeginOpLowering, DynamicModelEndOpLowering,
EquationFunctionOpLowering, FunctionOpLowering, ReturnOpLowering>(
&getContext());

return applyPartialConversion(getOperation(), target, std::move(patterns));
}
Expand Down
86 changes: 86 additions & 0 deletions lib/Codegen/Lowering/ModelLowerer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "marco/Codegen/Lowering/BaseModelica/ModelLowerer.h"
#include "marco/Dialect/BaseModelica/IR/Ops.h"

using namespace ::marco;
using namespace ::marco::codegen;
Expand Down Expand Up @@ -146,6 +147,11 @@ bool ModelLowerer::lower(const ast::bmodelica::Model &model) {
}
}

// Lower experiment annotation
if (!lowerExperimentAnnotation(modelOp, model)) {
return false;
}

return true;
}

Expand Down Expand Up @@ -269,4 +275,84 @@ bool ModelLowerer::lowerVariableAttributes(

return true;
}

bool ModelLowerer::lowerExperimentAnnotation(
mlir::bmodelica::ModelOp modelOp, const ast::bmodelica::Model &model) {
if (!model.hasAnnotation()) {
return true;
}

const auto *const annotation = model.getAnnotation();
const auto *const classModification = annotation->getProperties();

std::optional<double> startTime, stopTime;
for (const auto &argumentNode : classModification->getArguments()) {
const auto *const elementModification =
argumentNode->cast<ast::bmodelica::Argument>()
->dyn_cast<ast::bmodelica::ElementModification>();

if (!elementModification) {
continue;
}
if (elementModification->getName() != "experiment") {
continue;
}
if (!elementModification->hasModification()) {
continue;
}

const auto *const experimentModification =
elementModification->getModification();
if (!experimentModification->hasClassModification()) {
continue;
}
const auto *const experimentClassModification =
experimentModification->getClassModification();

for (const auto &experimentArgumentNode :
experimentClassModification->getArguments()) {
const auto *const argument =
experimentArgumentNode->cast<ast::bmodelica::Argument>()
->dyn_cast<ast::bmodelica::ElementModification>();

const auto argumentName = argument->getName();
if (argumentName != "StartTime" && argumentName != "StopTime") {
continue;
}

if (!argument->hasModification()) {
continue;
}
const auto *const argumentModification = argument->getModification();

if (!argumentModification->hasExpression()) {
continue;
}
const auto *const constant = argumentModification->getExpression()
->dyn_cast<ast::bmodelica::Constant>();
if (!constant) {
continue;
}

const auto value = constant->as<double>();
if (argumentName == "StartTime") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe emit a warning in case of case-sensitive inequality but case-insensitive match? (not sure you can do it with the current infrastructure, but at least add it as a todo).

startTime = value;
} else if (argumentName == "StopTime") {
stopTime = value;
}
}
}

if (startTime.has_value()) {
modelOp->setAttr(modelOp.getExperimentStartTimeAttrName(),
builder().getF64FloatAttr(startTime.value()));
}
if (stopTime.has_value()) {
modelOp->setAttr(modelOp.getExperimentEndTimeAttrName(),
builder().getF64FloatAttr(stopTime.value()));
}

// TODO: actually handle errors; right now we are *very* permissive
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't be too restrictive on this. The specification of the language is quite permissive by nature. If we have a match on the expected annotations tree and value types, then we emit the attribute. Otherwise (e.g., a string given as start time), just emit a warning rather than an error.

return true;
}
} // namespace marco::codegen::lowering::bmodelica
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// RUN: modelica-opt %s --split-input-file --convert-bmodelica-to-runtime | FileCheck %s

// COM: Model with both start and end time annotations.

// CHECK-LABEL: module {
// CHECK: runtime.start_time 0.000000e+00
// CHECK: runtime.end_time 1.000000e+00

bmodelica.model @Test attributes {experiment.startTime = 0.000000e+00 : f64, experiment.endTime = 1.000000e+00 : f64} {

}

// -----

// COM: Model with only start time annotation.

// CHECK-LABEL: module {
// CHECK: runtime.start_time 5.000000e-01
// CHECK: runtime.end_time

bmodelica.model @Test2 attributes {experiment.startTime = 5.000000e-01 : f64} {

}

// -----

// COM: Model with only end time annotation.

// CHECK-LABEL: module {
// CHECK: runtime.start_time
// CHECK: runtime.end_time 2.000000e+00

bmodelica.model @Test3 attributes {experiment.endTime = 2.000000e+00 : f64} {

}

// -----

// COM: Model without experiment annotations.

// CHECK-LABEL: module {
// CHECK: runtime.start_time
// CHECK: runtime.end_time

bmodelica.model @Test4 {

}
Loading