diff --git a/specdb/db.py b/specdb/db.py index aa1dce6..5cd03d3 100644 --- a/specdb/db.py +++ b/specdb/db.py @@ -341,7 +341,16 @@ ), ], outspec=[ - OutArg(ArgType.Tensor), + OutArg( + ArgType.Tensor, + constraints=[ + cp.Dtype.In( + lambda deps: dt.can_cast_from( + torch.promote_types(deps[0].dtype, deps[1].dtype) + ) + ), + ], + ), ], ), Spec( @@ -373,7 +382,16 @@ ), ], outspec=[ - OutArg(ArgType.Tensor), + OutArg( + ArgType.Tensor, + constraints=[ + cp.Dtype.Eq( + lambda deps: ( + fn.promote_type_with_scalar(deps[0].dtype, deps[1]) + ) + ), + ], + ) ], ), Spec(