diff --git a/tripy/tripy/frontend/ops/tensor_initializers.py b/tripy/tripy/frontend/ops/tensor_initializers.py index da5b1e221..4d21af949 100644 --- a/tripy/tripy/frontend/ops/tensor_initializers.py +++ b/tripy/tripy/frontend/ops/tensor_initializers.py @@ -281,7 +281,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor": @constraints.dtypes( constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"], + "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], }, ) def arange( @@ -346,7 +346,7 @@ def arange( @constraints.dtypes( constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"], + "T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"], }, ) def arange( diff --git a/tripy/tripy/frontend/trace/ops/gather.py b/tripy/tripy/frontend/trace/ops/gather.py index 0df478585..15c66633d 100644 --- a/tripy/tripy/frontend/trace/ops/gather.py +++ b/tripy/tripy/frontend/trace/ops/gather.py @@ -91,7 +91,7 @@ def to_flat_ir(self, inputs, outputs): @constraints.dtypes( constraints={"input": "T1", "index": "T2", constraints.RETURN_VALUE: "T1"}, variables={ - "T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"], + "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], "T2": ["int32"], }, ) diff --git a/tripy/tripy/frontend/trace/ops/iota.py b/tripy/tripy/frontend/trace/ops/iota.py index a252d756c..98933f92a 100644 --- a/tripy/tripy/frontend/trace/ops/iota.py +++ b/tripy/tripy/frontend/trace/ops/iota.py @@ -69,7 +69,7 @@ def iota_impl(shape: "tripy.Tensor", dim: int, dtype: datatype.dtype, output_ran @constraints.dtypes( constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"}, variables={ - "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"], + "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor": @@ -101,7 +101,7 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float3 constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"}, variables={ "T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], - "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"], + "T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"], }, ) def iota_like(input: "tripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtype] = None) -> "tripy.Tensor": diff --git a/tripy/tripy/frontend/trace/ops/reduce.py b/tripy/tripy/frontend/trace/ops/reduce.py index e12a8129a..78353ec4f 100644 --- a/tripy/tripy/frontend/trace/ops/reduce.py +++ b/tripy/tripy/frontend/trace/ops/reduce.py @@ -413,7 +413,7 @@ def _arg_min_max_impl(tensor: "tripy.Tensor", kind: ArgMinMax.Kind, dim: Optiona @export.public_api(document_under="operations/functions") @constraints.dtypes( constraints={"input": "T1", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]}, + variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, ) def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor": """ @@ -445,7 +445,7 @@ def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = Fal @export.public_api(document_under="operations/functions") @constraints.dtypes( constraints={"input": "T1", constraints.RETURN_VALUE: "T2"}, - variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]}, + variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]}, ) def argmin(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor": """