Skip to content

Commit

Permalink
Fix type constraints tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jhalakpatel committed Nov 6, 2024
1 parent e037aff commit 1914775
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/ops/tensor_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
@constraints.dtypes(
constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
variables={
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"],
},
)
def arange(
Expand Down Expand Up @@ -346,7 +346,7 @@ def arange(
@constraints.dtypes(
constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
variables={
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "int64", "bool"],
},
)
def arange(
Expand Down
2 changes: 1 addition & 1 deletion tripy/tripy/frontend/trace/ops/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def to_flat_ir(self, inputs, outputs):
@constraints.dtypes(
constraints={"input": "T1", "index": "T2", constraints.RETURN_VALUE: "T1"},
variables={
"T1": ["float32", "float16", "bfloat16", "int8", "int32", "bool"],
"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
"T2": ["int32"],
},
)
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/iota.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def iota_impl(shape: "tripy.Tensor", dim: int, dtype: datatype.dtype, output_ran
@constraints.dtypes(
constraints={"dtype": "T1", constraints.RETURN_VALUE: "T1"},
variables={
"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"],
"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
},
)
def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float32) -> "tripy.Tensor":
Expand Down Expand Up @@ -101,7 +101,7 @@ def iota(shape: ShapeLike, dim: int = 0, dtype: datatype.dtype = datatype.float3
constraints={"input": "T1", "dtype": "T2", constraints.RETURN_VALUE: "T2"},
variables={
"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
"T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "bool"],
"T2": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
},
)
def iota_like(input: "tripy.Tensor", dim: int = 0, dtype: Optional[datatype.dtype] = None) -> "tripy.Tensor":
Expand Down
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/trace/ops/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _arg_min_max_impl(tensor: "tripy.Tensor", kind: ArgMinMax.Kind, dim: Optiona
@export.public_api(document_under="operations/functions")
@constraints.dtypes(
constraints={"input": "T1", constraints.RETURN_VALUE: "T2"},
variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]},
variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]},
)
def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor":
"""
Expand Down Expand Up @@ -445,7 +445,7 @@ def argmax(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = Fal
@export.public_api(document_under="operations/functions")
@constraints.dtypes(
constraints={"input": "T1", constraints.RETURN_VALUE: "T2"},
variables={"T1": ["float32", "float16", "bfloat16", "int32", "bool", "int8"], "T2": ["int32"]},
variables={"T1": ["float32", "float16", "bfloat16", "int32"], "T2": ["int32"]},
)
def argmin(input: "tripy.Tensor", dim: Optional[int] = None, keepdim: bool = False) -> "tripy.Tensor":
"""
Expand Down

0 comments on commit 1914775

Please sign in to comment.