Skip to content

Commit

Permalink
SpecDB: Add OutTensor specs for add.Tensor & add.Scalar
Browse files Browse the repository at this point in the history
Differential Revision: D59402158
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Jul 5, 2024
1 parent 221e7e7 commit 5fb625d
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,14 @@
),
],
outspec=[
OutArg(ArgType.Tensor),
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.Eq(
lambda deps: torch.promote_types(deps[0].dtype, deps[1].dtype)
),
],
)
],
),
Spec(
Expand Down Expand Up @@ -373,7 +380,16 @@
),
],
outspec=[
OutArg(ArgType.Tensor),
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.Eq(
lambda deps: (
fn.promote_type_with_scalar(deps[0].dtype, deps[1])
)
),
],
)
],
),
Spec(
Expand Down

0 comments on commit 5fb625d

Please sign in to comment.