
Commit
[HLS] Add FloatParamType and IntParamType
hanchenye committed Nov 14, 2023
1 parent 570ba8c commit cd98fe3
Showing 10 changed files with 91 additions and 23 deletions.
6 changes: 6 additions & 0 deletions include/scalehls-c/Dialect/HLS/HLS.h
@@ -35,6 +35,12 @@ MLIR_CAPI_EXPORTED MlirType mlirHLSStructTypeGet(MlirStringRef name,
MLIR_CAPI_EXPORTED bool mlirTypeIsHLSTypeType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirHLSTypeTypeGet(MlirContext ctx);

+MLIR_CAPI_EXPORTED bool mlirTypeIsHLSIntParamType(MlirType type);
+MLIR_CAPI_EXPORTED MlirType mlirHLSIntParamTypeGet(MlirContext ctx);
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsHLSFloatParamType(MlirType type);
+MLIR_CAPI_EXPORTED MlirType mlirHLSFloatParamTypeGet(MlirContext ctx);
+
MLIR_CAPI_EXPORTED bool mlirTypeIsHLSPortType(MlirType type);
MLIR_CAPI_EXPORTED MlirType mlirHLSPortTypeGet(MlirContext ctx);

3 changes: 3 additions & 0 deletions include/scalehls/Dialect/HLS/IR/HLSOps.td
@@ -32,6 +32,9 @@ def AnyBuffer : StaticShapeMemRefOf<[AnyType]>;
def AnyStream : StreamOf<[AnyType]>;
def AnyBufferOrStream : Type<Or<[AnyBuffer.predicate, AnyStream.predicate]>,
"memref or stream values">;
+def FloatOrIntParamType : Type<Or<[FloatParamType.predicate,
+IntParamType.predicate, TypeType.predicate]>,
+"float or integer parameter types">;

//===----------------------------------------------------------------------===//
// HLS Operations
23 changes: 21 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSTypes.td
@@ -9,8 +9,11 @@

include "scalehls/Dialect/HLS/IR/HLSAttributes.td"

-class HLSType<string name, list<Trait> traits = []> :
-TypeDef<HLSDialect, name, traits>;
+class HLSType<string name, list<Trait> traits = [],
+string baseCppClass = "::mlir::Type">
+: TypeDef<HLSDialect, name, traits, baseCppClass> {
+let mnemonic = ?;
+}

def StreamType : HLSType<"Stream"> {
let summary = "An HLS stream type";
@@ -51,6 +54,22 @@ def TypeType : HLSType<"Type"> {
let mnemonic = "type";
}

+def IntParamType : HLSType<"IntParam", [], "mlir::IntegerType"> {
+let summary = "Used to represent an integer parameter";
+let mnemonic = "int_param";
+
+let builders = [
+TypeBuilder<(ins "mlir::MLIRContext *":$context), [{
+return $_get(context, mlir::IntegerType::k);
+}]>
+];
+}
+
+def FloatParamType : HLSType<"FloatParam", [], "mlir::FloatType"> {
+let summary = "Used to represent a float parameter";
+let mnemonic = "float_param";
+}
+
def PortType : HLSType<"Port"> {
let summary = "Used to represent an input/output";
let mnemonic = "port";
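Both new type definitions carry mnemonics, so they should round-trip through the textual IR as !hls.int_param and !hls.float_param. A minimal parsing sketch in Python, assuming the package mirrors the upstream mlir.ir layout and registers the HLS dialect on context creation (the import paths are illustrative, not confirmed):

from scalehls.ir import Context, Type  # assumed layout, mirroring mlir.ir

with Context():
    # Assumes the HLS dialect is registered when the context is created.
    t = Type.parse("!hls.int_param")
    u = Type.parse("!hls.float_param")
    print(t, u)  # expected: !hls.int_param !hls.float_param
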
4 changes: 2 additions & 2 deletions include/scalehls/Dialect/HLS/IR/HLSUIPOps.td
@@ -77,7 +77,7 @@ def PortOp : HLSOp<"uip.port", [Symbol, AttrSizedOperandSegments,
HasParent<"DeclareOp">]> {
let summary = "Declare a port of an IP";

-let arguments = (ins TypeType:$type, Variadic<AnyType>:$dims,
+let arguments = (ins FloatOrIntParamType:$type, Variadic<AnyType>:$dims,
Variadic<AnyType>:$symbols, PortKindAttr:$kind,
OptionalAttr<MemRefLayoutAttrInterface>:$stream_layout,
MemRefLayoutAttrInterface:$memory_layout,
@@ -87,7 +87,7 @@ def PortOp : HLSOp<"uip.port", [Symbol, AttrSizedOperandSegments,
let assemblyFormat = [{
$sym_name $kind `type` $type (`[` $dims^ `]`)? (`(` $symbols^ `)`)?
(`stream_layout` $stream_layout^)? `memory_layout` $memory_layout attr-dict
-`:` (`[` type($dims)^ `]`)? functional-type($symbols, $result)
+`:` type($type) (`[` type($dims)^ `]`)? functional-type($symbols, $result)
}];

let extraClassDeclaration = [{
20 changes: 20 additions & 0 deletions lib/Bindings/Python/HLSDialect.cpp
@@ -87,6 +87,26 @@ void populateHLSTypes(py::module &m) {
"Get an instance of TypeType in given context.", py::arg("cls"),
py::arg("context") = py::none());

+auto intParamType =
+mlir_type_subclass(m, "IntParamType", mlirTypeIsHLSIntParamType);
+intParamType.def_classmethod(
+"get",
+[](py::object cls, MlirContext ctx) {
+return cls(mlirHLSIntParamTypeGet(ctx));
+},
+"Get an instance of IntParamType in given context.", py::arg("cls"),
+py::arg("context") = py::none());
+
+auto floatParamType =
+mlir_type_subclass(m, "FloatParamType", mlirTypeIsHLSFloatParamType);
+floatParamType.def_classmethod(
+"get",
+[](py::object cls, MlirContext ctx) {
+return cls(mlirHLSFloatParamTypeGet(ctx));
+},
+"Get an instance of FloatParamType in given context.", py::arg("cls"),
+py::arg("context") = py::none());
+
auto portType = mlir_type_subclass(m, "PortType", mlirTypeIsHLSPortType);
portType.def_classmethod(
"get",
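The new bindings follow the same pattern as the existing TypeType subclass: a "get" classmethod whose context argument defaults to py::none(). A hedged usage sketch, assuming the extension is re-exported under scalehls.dialects.hls:

from scalehls.ir import Context  # assumed layout, mirroring mlir.ir
from scalehls.dialects import hls

with Context() as ctx:
    # `context` falls back to the current context when omitted.
    i = hls.IntParamType.get(context=ctx)
    f = hls.FloatParamType.get(context=ctx)
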
14 changes: 14 additions & 0 deletions lib/CAPI/Dialect/HLS/HLS.cpp
@@ -45,6 +45,20 @@ MlirType mlirHLSTypeTypeGet(MlirContext ctx) {
return wrap(hls::TypeType::get(unwrap(ctx)));
}

+bool mlirTypeIsHLSIntParamType(MlirType type) {
+return unwrap(type).isa<hls::IntParamType>();
+}
+MlirType mlirHLSIntParamTypeGet(MlirContext ctx) {
+return wrap(hls::IntParamType::get(unwrap(ctx)));
+}
+
+bool mlirTypeIsHLSFloatParamType(MlirType type) {
+return unwrap(type).isa<hls::FloatParamType>();
+}
+MlirType mlirHLSFloatParamTypeGet(MlirContext ctx) {
+return wrap(hls::FloatParamType::get(unwrap(ctx)));
+}
+
bool mlirTypeIsHLSPortType(MlirType type) {
return unwrap(type).isa<hls::PortType>();
}
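These predicates back the isinstance checks that mlir_type_subclass generates for the Python classes above. A round-trip check, under the same assumed import paths:

from scalehls.ir import Context
from scalehls.dialects import hls

with Context() as ctx:
    t = hls.IntParamType.get(context=ctx)
    # isinstance is generated from mlirTypeIsHLSIntParamType /
    # mlirTypeIsHLSFloatParamType.
    assert hls.IntParamType.isinstance(t)
    assert not hls.FloatParamType.isinstance(t)
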
4 changes: 2 additions & 2 deletions lib/Dialect/HLS/IR/HLSUIPOps.cpp
@@ -132,11 +132,11 @@ void SemanticsOp::initializeBlockArguments(
auto port = value.getDefiningOp<PortOp>();
assert(port && port.getKind() != PortKind::PARAM && "invalid port");
if (port.getDims().empty())
-argTypes.push_back(/*port.getType().getType()*/ builder.getF32Type());
+argTypes.push_back(port.getType().getType() /*builder.getF32Type()*/);
else
argTypes.push_back(RankedTensorType::get(
SmallVector<int64_t>(port.getDims().size(), ShapedType::kDynamic),
-/*port.getType().getType()*/ builder.getF32Type(), nullptr));
+port.getType().getType() /*builder.getF32Type()*/, nullptr));
argLocs.push_back(port.getLoc());
}

4 changes: 4 additions & 0 deletions python/scalehls/_mlir_libs/_hls_dialect.pyi
@@ -2,6 +2,8 @@
from __future__ import annotations
import _hls_dialect
import typing
+from importlib._bootstrap import FloatParamType
+from importlib._bootstrap import IntParamType
from importlib._bootstrap import MemoryKindType
from importlib._bootstrap import ParamKindAttr
from importlib._bootstrap import PortKindAttr
@@ -11,6 +13,8 @@ from importlib._bootstrap import TaskImplType
from importlib._bootstrap import TypeType

__all__ = [
"FloatParamType",
"IntParamType",
"MemoryKindType",
"ParamKind",
"ParamKindAttr",
4 changes: 3 additions & 1 deletion python/scalehls/opdsl/lang/emitter.py
@@ -11,6 +11,7 @@
from ...dialects import math
from ...dialects import arith
from ...dialects import complex
+from ...dialects import hls
from ...dialects._ods_common import (
get_op_result_or_value as _get_op_result_or_value,
get_op_results_or_values as _get_op_results_or_values,
@@ -636,11 +637,12 @@ def _is_floating_point_type(t: Type) -> bool:
or F32Type.isinstance(t)
or F16Type.isinstance(t)
or BF16Type.isinstance(t)
+or isinstance(t, hls.FloatParamType)
)


def _is_integer_type(t: Type) -> bool:
-return IntegerType.isinstance(t)
+return (IntegerType.isinstance(t) or isinstance(t, hls.IntParamType))


def _is_index_type(t: Type) -> bool:
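With the two predicates widened, the OpDSL emitter should classify the new parameter types alongside the builtin float and integer types. An illustrative check (the underscored helpers are module-private, so importing them here is purely for demonstration):

from scalehls.ir import Context
from scalehls.dialects import hls
from scalehls.opdsl.lang.emitter import _is_floating_point_type, _is_integer_type

with Context() as ctx:
    assert _is_floating_point_type(hls.FloatParamType.get(context=ctx))
    assert _is_integer_type(hls.IntParamType.get(context=ctx))
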
32 changes: 16 additions & 16 deletions test/EmitHLSCpp/test-instance.mlir
@@ -8,15 +8,15 @@ module attributes { torch.debug_module_name = "MLP" } {
hls.uip.library @testLib {
hls.uip.declare @testIp {
hls.uip.include ["Path/to/test.hpp"]
-%1 = hls.dse.param @template1 <template> candidates [f32] : !hls.type
-%2 = hls.dse.param @template2 <template> candidates [index] : !hls.type
+%1 = hls.dse.param @template1 <template> candidates [f32] : !hls.float_param
+%2 = hls.dse.param @template2 <template> candidates [index] : !hls.int_param
%3 = hls.dse.param @template3 <template> candidates [4 : index] : index
-%4 = hls.uip.port @para1 <param> type %2 memory_layout #map : () -> !hls.port
-%5 = hls.uip.port @para2 <param> type %2 memory_layout #map : () -> !hls.port
-%6 = hls.uip.port @input1 <input> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-%7 = hls.uip.port @input2 <input> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-%8 = hls.uip.port @output1 <output> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-hls.uip.semantics<%1, %2, %3> (%4, %5, %6, %7, %8) [2 : index, 3 : index, 4 : index] : <!hls.type, !hls.type, index> (!hls.port, !hls.port, !hls.port, !hls.port, !hls.port) {
+%4 = hls.uip.port @para1 <param> type %2 memory_layout #map : !hls.int_param () -> !hls.port
+%5 = hls.uip.port @para2 <param> type %2 memory_layout #map : !hls.int_param () -> !hls.port
+%6 = hls.uip.port @input1 <input> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+%7 = hls.uip.port @input2 <input> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+%8 = hls.uip.port @output1 <output> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+hls.uip.semantics<%1, %2, %3> (%4, %5, %6, %7, %8) [2 : index, 3 : index, 4 : index] : <!hls.float_param, !hls.int_param, index> (!hls.port, !hls.port, !hls.port, !hls.port, !hls.port) {
^bb0(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>):
%9 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
@@ -28,15 +28,15 @@
}
hls.uip.declare @not_used_Ip {
hls.uip.include ["Path/to/no_used_test.hpp"]
-%1 = hls.dse.param @template1 <template> candidates [f32] : !hls.type
-%2 = hls.dse.param @template2 <template> candidates [index] : !hls.type
+%1 = hls.dse.param @template1 <template> candidates [f32] : !hls.float_param
+%2 = hls.dse.param @template2 <template> candidates [index] : !hls.int_param
%3 = hls.dse.param @template3 <template> candidates [4 : index] : index
-%4 = hls.uip.port @para1 <param> type %2 memory_layout #map : () -> !hls.port
-%5 = hls.uip.port @para2 <param> type %2 memory_layout #map : () -> !hls.port
-%6 = hls.uip.port @input1 <input> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-%7 = hls.uip.port @input2 <input> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-%8 = hls.uip.port @output1 <output> type %1 [%4, %5] memory_layout #map1 : [!hls.port, !hls.port] () -> !hls.port
-hls.uip.semantics<%1, %2, %3> (%4, %5, %6, %7, %8) [2 : index, 3 : index, 4 : index] : <!hls.type, !hls.type, index> (!hls.port, !hls.port, !hls.port, !hls.port, !hls.port) {
+%4 = hls.uip.port @para1 <param> type %2 memory_layout #map : !hls.int_param () -> !hls.port
+%5 = hls.uip.port @para2 <param> type %2 memory_layout #map : !hls.int_param () -> !hls.port
+%6 = hls.uip.port @input1 <input> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+%7 = hls.uip.port @input2 <input> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+%8 = hls.uip.port @output1 <output> type %1 [%4, %5] memory_layout #map1 : !hls.float_param [!hls.port, !hls.port] () -> !hls.port
+hls.uip.semantics<%1, %2, %3> (%4, %5, %6, %7, %8) [2 : index, 3 : index, 4 : index] : <!hls.float_param, !hls.int_param, index> (!hls.port, !hls.port, !hls.port, !hls.port, !hls.port) {
^bb0(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>):
%9 = linalg.generic {indexing_maps = [#map1, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<?x?xf32>) outs(%arg1 : tensor<?x?xf32>) {
^bb0(%in: f32, %out: f32):
