From 4c7affbe3b5e17d7be07fca30ea5a923fa408f3a Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Tue, 27 Aug 2024 14:48:26 -0700 Subject: [PATCH] SpecDB: Add spec: scatter.value Reviewed By: JacobSzwejbka Differential Revision: D61874510 fbshipit-source-id: 11622265cc0fda1b75104697b7cc4443fe6aeca5 --- specdb/db.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/specdb/db.py b/specdb/db.py index 1e614ae..57cc845 100644 --- a/specdb/db.py +++ b/specdb/db.py @@ -3478,6 +3478,78 @@ ], outspec=[OutArg(ArgType.Tensor)], ), + Spec( # TODO(mcandales): Calibrate. + op="scatter.value", # (Tensor self, int dim, Tensor index, Scalar value) -> Tensor + inspec=[ + InPosArg(ArgType.Tensor, name="self"), + InPosArg( + ArgType.Dim, + name="dim", + deps=[0], + constraints=[ + cp.Value.In(lambda deps: fn.dim_non_zero_size(deps[0])), + ], + ), + InPosArg( + ArgType.Tensor, + name="index", + deps=[0, 1], + # TODO(mcandales) Handle index.numel() == 0 case + constraints=[ + cp.Dtype.Eq(lambda deps: torch.long), + cp.Rank.Eq( + lambda deps: deps[0].dim() if deps[0].dim() >= 2 else None + ), + cp.Rank.In( + lambda deps: [0, 1] if deps[0].dim() in [0, 1] else None + ), + cp.Size.Le( + lambda deps, r, d: ( + fn.safe_size(deps[0], d) + if d != fn.normalize(deps[1], deps[0].dim()) + else None + ) + ), + cp.Value.Ge(lambda deps, dtype, struct: 0), + cp.Value.Le( + lambda deps, dtype, struct: ( + 0 + if deps[0].dim() == 0 + else max(0, fn.safe_size(deps[0], deps[1]) - 1) + ) + ), + ], + ), + InPosArg( + ArgType.Scalar, + name="value", + deps=[0], + constraints=[ + cp.Value.NotIn( + lambda deps, dtype: ( + [float("-inf"), float("inf")] + if deps[0].dtype not in dt._floating + else None + ) + ), + cp.Value.Ge( + lambda deps, dtype: fn.dtype_lower_bound(deps[0].dtype) + ), + cp.Value.Le( + lambda deps, dtype: fn.dtype_upper_bound(deps[0].dtype) + ), + ], + ), + ], + outspec=[ + OutArg( + ArgType.Tensor, + constraints=[ + cp.Dtype.Eq(lambda deps: deps[0].dtype), + ], + ), + ], + ), Spec( # TODO(mcandales): Calibrate. op="scatter_add.default", # (Tensor self, int dim, Tensor index, Tensor src) -> Tensor inspec=[