diff --git a/include/scalehls-c/Dialect/HLS/HLS.h b/include/scalehls-c/Dialect/HLS/HLS.h index 70c9a39a..fb472feb 100644 --- a/include/scalehls-c/Dialect/HLS/HLS.h +++ b/include/scalehls-c/Dialect/HLS/HLS.h @@ -35,6 +35,12 @@ MLIR_CAPI_EXPORTED MlirType mlirHLSStructTypeGet(MlirStringRef name, MLIR_CAPI_EXPORTED bool mlirTypeIsHLSTypeType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirHLSTypeTypeGet(MlirContext ctx); +MLIR_CAPI_EXPORTED bool mlirTypeIsHLSIntParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirHLSIntParamTypeGet(MlirContext ctx); + +MLIR_CAPI_EXPORTED bool mlirTypeIsHLSFloatParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirType mlirHLSFloatParamTypeGet(MlirContext ctx); + MLIR_CAPI_EXPORTED bool mlirTypeIsHLSPortType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirHLSPortTypeGet(MlirContext ctx); diff --git a/include/scalehls/Dialect/HLS/IR/HLSOps.td b/include/scalehls/Dialect/HLS/IR/HLSOps.td index e732812a..20c4091c 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSOps.td @@ -32,6 +32,9 @@ def AnyBuffer : StaticShapeMemRefOf<[AnyType]>; def AnyStream : StreamOf<[AnyType]>; def AnyBufferOrStream : Type, "memref or stream values">; +def FloatOrIntParamType : Type, + "float or integer parameter types">; //===----------------------------------------------------------------------===// // HLS Operations diff --git a/include/scalehls/Dialect/HLS/IR/HLSTypes.td b/include/scalehls/Dialect/HLS/IR/HLSTypes.td index dc25114a..fc9eb850 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSTypes.td +++ b/include/scalehls/Dialect/HLS/IR/HLSTypes.td @@ -9,8 +9,11 @@ include "scalehls/Dialect/HLS/IR/HLSAttributes.td" -class HLSType traits = []> : - TypeDef; +class HLSType traits = [], + string baseCppClass = "::mlir::Type"> + : TypeDef { + let mnemonic = ?; +} def StreamType : HLSType<"Stream"> { let summary = "An HLS stream type"; @@ -51,6 +54,22 @@ def TypeType : HLSType<"Type"> { let mnemonic = "type"; } +def IntParamType : HLSType<"IntParam", [], "mlir::IntegerType"> { + let summary = "Used to represent an integer parameter"; + let mnemonic = "int_param"; + + let builders = [ + TypeBuilder<(ins "mlir::MLIRContext *":$context), [{ + return $_get(context, mlir::IntegerType::k); + }]> + ]; +} + +def FloatParamType : HLSType<"FloatParam", [], "mlir::FloatType"> { + let summary = "Used to represent a float parameter"; + let mnemonic = "float_param"; +} + def PortType : HLSType<"Port"> { let summary = "Used to represent an input/output"; let mnemonic = "port"; diff --git a/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td b/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td index 0c7e585c..70302a64 100644 --- a/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td +++ b/include/scalehls/Dialect/HLS/IR/HLSUIPOps.td @@ -77,7 +77,7 @@ def PortOp : HLSOp<"uip.port", [Symbol, AttrSizedOperandSegments, HasParent<"DeclareOp">]> { let summary = "Declare a port of an IP"; - let arguments = (ins TypeType:$type, Variadic:$dims, + let arguments = (ins FloatOrIntParamType:$type, Variadic:$dims, Variadic:$symbols, PortKindAttr:$kind, OptionalAttr:$stream_layout, MemRefLayoutAttrInterface:$memory_layout, @@ -87,7 +87,7 @@ def PortOp : HLSOp<"uip.port", [Symbol, AttrSizedOperandSegments, let assemblyFormat = [{ $sym_name $kind `type` $type (`[` $dims^ `]`)? (`(` $symbols^ `)`)? (`stream_layout` $stream_layout^)? `memory_layout` $memory_layout attr-dict - `:` (`[` type($dims)^ `]`)? functional-type($symbols, $result) + `:` type($type) (`[` type($dims)^ `]`)? functional-type($symbols, $result) }]; let extraClassDeclaration = [{ diff --git a/lib/Bindings/Python/HLSDialect.cpp b/lib/Bindings/Python/HLSDialect.cpp index 50716f95..155f10e6 100644 --- a/lib/Bindings/Python/HLSDialect.cpp +++ b/lib/Bindings/Python/HLSDialect.cpp @@ -87,6 +87,26 @@ void populateHLSTypes(py::module &m) { "Get an instance of TypeType in given context.", py::arg("cls"), py::arg("context") = py::none()); + auto intParamType = + mlir_type_subclass(m, "IntParamType", mlirTypeIsHLSIntParamType); + intParamType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirHLSIntParamTypeGet(ctx)); + }, + "Get an instance of IntParamType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + auto floatParamType = + mlir_type_subclass(m, "FloatParamType", mlirTypeIsHLSFloatParamType); + floatParamType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirHLSFloatParamTypeGet(ctx)); + }, + "Get an instance of FloatParamType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + auto portType = mlir_type_subclass(m, "PortType", mlirTypeIsHLSPortType); portType.def_classmethod( "get", diff --git a/lib/CAPI/Dialect/HLS/HLS.cpp b/lib/CAPI/Dialect/HLS/HLS.cpp index 0ffdf7dc..aeb73fed 100644 --- a/lib/CAPI/Dialect/HLS/HLS.cpp +++ b/lib/CAPI/Dialect/HLS/HLS.cpp @@ -45,6 +45,20 @@ MlirType mlirHLSTypeTypeGet(MlirContext ctx) { return wrap(hls::TypeType::get(unwrap(ctx))); } +bool mlirTypeIsHLSIntParamType(MlirType type) { + return unwrap(type).isa(); +} +MlirType mlirHLSIntParamTypeGet(MlirContext ctx) { + return wrap(hls::IntParamType::get(unwrap(ctx))); +} + +bool mlirTypeIsHLSFloatParamType(MlirType type) { + return unwrap(type).isa(); +} +MlirType mlirHLSFloatParamTypeGet(MlirContext ctx) { + return wrap(hls::FloatParamType::get(unwrap(ctx))); +} + bool mlirTypeIsHLSPortType(MlirType type) { return unwrap(type).isa(); } diff --git a/lib/Dialect/HLS/IR/HLSUIPOps.cpp b/lib/Dialect/HLS/IR/HLSUIPOps.cpp index 999a9e0f..e5c8a684 100644 --- a/lib/Dialect/HLS/IR/HLSUIPOps.cpp +++ b/lib/Dialect/HLS/IR/HLSUIPOps.cpp @@ -132,11 +132,11 @@ void SemanticsOp::initializeBlockArguments( auto port = value.getDefiningOp(); assert(port && port.getKind() != PortKind::PARAM && "invalid port"); if (port.getDims().empty()) - argTypes.push_back(/*port.getType().getType()*/ builder.getF32Type()); + argTypes.push_back(port.getType().getType() /*builder.getF32Type()*/); else argTypes.push_back(RankedTensorType::get( SmallVector(port.getDims().size(), ShapedType::kDynamic), - /*port.getType().getType()*/ builder.getF32Type(), nullptr)); + port.getType().getType() /*builder.getF32Type()*/, nullptr)); argLocs.push_back(port.getLoc()); } diff --git a/python/scalehls/_mlir_libs/_hls_dialect.pyi b/python/scalehls/_mlir_libs/_hls_dialect.pyi index 7eeb231e..73ce55c4 100644 --- a/python/scalehls/_mlir_libs/_hls_dialect.pyi +++ b/python/scalehls/_mlir_libs/_hls_dialect.pyi @@ -2,6 +2,8 @@ from __future__ import annotations import _hls_dialect import typing +from importlib._bootstrap import FloatParamType +from importlib._bootstrap import IntParamType from importlib._bootstrap import MemoryKindType from importlib._bootstrap import ParamKindAttr from importlib._bootstrap import PortKindAttr @@ -11,6 +13,8 @@ from importlib._bootstrap import TaskImplType from importlib._bootstrap import TypeType __all__ = [ + "FloatParamType", + "IntParamType", "MemoryKindType", "ParamKind", "ParamKindAttr", diff --git a/python/scalehls/opdsl/lang/emitter.py b/python/scalehls/opdsl/lang/emitter.py index 390cdff0..52fa4711 100644 --- a/python/scalehls/opdsl/lang/emitter.py +++ b/python/scalehls/opdsl/lang/emitter.py @@ -11,6 +11,7 @@ from ...dialects import math from ...dialects import arith from ...dialects import complex +from ...dialects import hls from ...dialects._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, @@ -636,11 +637,12 @@ def _is_floating_point_type(t: Type) -> bool: or F32Type.isinstance(t) or F16Type.isinstance(t) or BF16Type.isinstance(t) + or isinstance(t, hls.FloatParamType) ) def _is_integer_type(t: Type) -> bool: - return IntegerType.isinstance(t) + return (IntegerType.isinstance(t) or isinstance(t, hls.IntParamType)) def _is_index_type(t: Type) -> bool: diff --git a/test/EmitHLSCpp/test-instance.mlir b/test/EmitHLSCpp/test-instance.mlir index 74ea911c..43c039c1 100644 --- a/test/EmitHLSCpp/test-instance.mlir +++ b/test/EmitHLSCpp/test-instance.mlir @@ -8,15 +8,15 @@ module attributes { torch.debug_module_name = "MLP" } { hls.uip.library @testLib { hls.uip.declare @testIp { hls.uip.include ["Path/to/test.hpp"] - %1 = hls.dse.param @template1