Skip to content

Commit

Permalink
SpecDB: Add spec: gather
Browse files Browse the repository at this point in the history
Reviewed By: JacobSzwejbka

Differential Revision: D61822096

fbshipit-source-id: ad9fc4129beaa2a98c9a61d73704f807b5e5a939
  • Loading branch information
manuelcandales authored and facebook-github-bot committed Aug 27, 2024
1 parent a7036d6 commit b8d1403
Showing 1 changed file with 53 additions and 0 deletions.
53 changes: 53 additions & 0 deletions specdb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1943,6 +1943,59 @@
OutArg(ArgType.Tensor),
],
),
Spec( # TODO(mcandales): Calibrate.
op="gather.default", # (Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
inspec=[
InPosArg(ArgType.Tensor, name="self"),
InPosArg(
ArgType.Dim,
name="dim",
deps=[0],
constraints=[
cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])),
],
),
InPosArg(
ArgType.Tensor,
name="index",
deps=[0, 1],
# TODO(mcandales) Handle index.numel() == 0 case
constraints=[
cp.Dtype.Eq(lambda deps: torch.long),
cp.Rank.Eq(
lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None
),
cp.Rank.In(
lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None
),
cp.Size.Le(
lambda deps, r, d: (
fn.safe_size(deps[0], d)
if d != fn.normalize(deps[1], deps[0].dim())
else None
)
),
cp.Value.Ge(lambda deps, dtype, struct: 0),
cp.Value.Le(
lambda deps, dtype, struct: (
0
if deps[0].dim() == 0
else max(0, fn.safe_size(deps[0], deps[1]) - 1)
)
),
],
),
InKwArg(ArgType.Bool, name="sparse_grad"),
],
outspec=[
OutArg(
ArgType.Tensor,
constraints=[
cp.Dtype.Eq(lambda deps: deps[0].dtype),
],
),
],
),
Spec(
op="ge.Scalar", # (Tensor self, Scalar other) -> Tensor
inspec=[
Expand Down

0 comments on commit b8d1403

Please sign in to comment.