Skip to content

Commit bfa957d

Browse files
sterliakovilevkivskyihauntsaninja
authored
Support expand_type with ParamSpec.{args,kwargs} (#20119)
Fixes #19839. Looks like it was relatively easy to do the right way, let me try! When splitting a callable/parameters into args and kwargs, we have the following options: * posonly - only goes to `*args`, required unless has a default. If we encounter such required arg, all previously collected optional args become required (this only happens due to faulty TVT expansion somewhere; probably I should look into that too) * kwonly - only goes to `**kwargs`, required unless has a default * pos-or-kw - goes to both * vararg - only goes to `*args` as an `Unpack` (possibly normalized by tuple constructor) * kwargs - only goes to `**kwargs` and is only used if there are no kwargs with known names, because PEP 728 is not yet implemented, so we have to choose between `dict` and `TypedDict`. (thoughts? Maybe it is better to prefer `dict` with `union(kwarg, *kwargs.values())` as value type? Either way I do not consider this question important as PEP728 will be eventually implemented, and we'll have `extra_items` for our `TypedDict`s) Applying these steps to every argument in order, we collect required and optional args and kwargs candidates. Now, the type of `**kwargs` is a `TypedDict` if we know any keys, `dict[str, KwargType]` if we only have something like `**kw: str`, and `dict[str, Never]` if no kwargs were found. The type of `*args` is union of all prefixes of `optional_args` concatenated with `required_args`: all required args must be there, and optional args can only be passed in order. Since it is uncommon to have a function with more than 10-20 args, I think this union is a reasonable solution. --------- Co-authored-by: Ivan Levkivskyi <[email protected]> Co-authored-by: Shantanu <[email protected]>
1 parent 5bba0c3 commit bfa957d

File tree

2 files changed

+242
-9
lines changed

2 files changed

+242
-9
lines changed

mypy/expandtype.py

Lines changed: 85 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable, Mapping
44
from typing import Final, TypeVar, cast, overload
55

6-
from mypy.nodes import ARG_STAR, FakeInfo, Var
6+
from mypy.nodes import ARG_STAR, ArgKind, FakeInfo, Var
77
from mypy.state import state
88
from mypy.types import (
99
ANY_STRATEGY,
@@ -270,19 +270,95 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
270270
),
271271
)
272272
elif isinstance(repl, Parameters):
273-
assert t.flavor == ParamSpecFlavor.BARE
274-
return Parameters(
275-
self.expand_types(t.prefix.arg_types) + repl.arg_types,
276-
t.prefix.arg_kinds + repl.arg_kinds,
277-
t.prefix.arg_names + repl.arg_names,
278-
variables=[*t.prefix.variables, *repl.variables],
279-
imprecise_arg_kinds=repl.imprecise_arg_kinds,
280-
)
273+
assert isinstance(t.upper_bound, ProperType) and isinstance(t.upper_bound, Instance)
274+
if t.flavor == ParamSpecFlavor.BARE:
275+
return Parameters(
276+
self.expand_types(t.prefix.arg_types) + repl.arg_types,
277+
t.prefix.arg_kinds + repl.arg_kinds,
278+
t.prefix.arg_names + repl.arg_names,
279+
variables=[*t.prefix.variables, *repl.variables],
280+
imprecise_arg_kinds=repl.imprecise_arg_kinds,
281+
)
282+
elif t.flavor == ParamSpecFlavor.ARGS:
283+
assert all(k.is_positional() for k in t.prefix.arg_kinds)
284+
return self._possible_callable_varargs(
285+
repl, list(t.prefix.arg_types), t.upper_bound
286+
)
287+
else:
288+
assert t.flavor == ParamSpecFlavor.KWARGS
289+
return self._possible_callable_kwargs(repl, t.upper_bound)
281290
else:
282291
# We could encode Any as trivial parameters etc., but it would be too verbose.
283292
# TODO: assert this is a trivial type, like Any, Never, or object.
284293
return repl
285294

295+
@classmethod
296+
def _possible_callable_varargs(
297+
cls, repl: Parameters, required_prefix: list[Type], tuple_type: Instance
298+
) -> ProperType:
299+
"""Given a callable, extract all parameters that can be passed as `*args`.
300+
301+
This builds a union of all (possibly variadic) tuples representing all possible
302+
argument sequences that can be passed positionally. Each such tuple starts with
303+
all required (pos-only without a default) arguments, followed by some prefix
304+
of other arguments that can be passed positionally.
305+
"""
306+
required_posargs = required_prefix
307+
if repl.variables:
308+
# We will tear the callable apart, do not leak type variables
309+
return tuple_type
310+
optional_posargs: list[Type] = []
311+
for kind, name, type in zip(repl.arg_kinds, repl.arg_names, repl.arg_types):
312+
if kind == ArgKind.ARG_POS and name is None:
313+
if optional_posargs:
314+
# May happen following Unpack expansion without kinds correction
315+
required_posargs += optional_posargs
316+
optional_posargs = []
317+
required_posargs.append(type)
318+
elif kind.is_positional():
319+
optional_posargs.append(type)
320+
elif kind == ArgKind.ARG_STAR:
321+
if isinstance(type, UnpackType):
322+
optional_posargs.append(type)
323+
else:
324+
optional_posargs.append(UnpackType(Instance(tuple_type.type, [type])))
325+
break
326+
return UnionType.make_union(
327+
[
328+
TupleType(required_posargs + optional_posargs[:i], fallback=tuple_type)
329+
for i in range(len(optional_posargs) + 1)
330+
]
331+
)
332+
333+
@classmethod
334+
def _possible_callable_kwargs(cls, repl: Parameters, dict_type: Instance) -> ProperType:
335+
"""Given a callable, extract all parameters that can be passed as `**kwargs`.
336+
337+
If the function only accepts **kwargs, this will be a `dict[str, KwargsValueType]`.
338+
Otherwise, this will be a `TypedDict` containing all explicit args and ignoring
339+
`**kwargs` (until PEP 728 `extra_items` is supported). TypedDict entries will
340+
be required iff the corresponding argument is kw-only and has no default.
341+
"""
342+
if repl.variables:
343+
# We will tear the callable apart, do not leak type variables
344+
return dict_type
345+
kwargs = {}
346+
required_names = set()
347+
extra_items: Type = UninhabitedType()
348+
for kind, name, type in zip(repl.arg_kinds, repl.arg_names, repl.arg_types):
349+
if kind == ArgKind.ARG_NAMED and name is not None:
350+
kwargs[name] = type
351+
required_names.add(name)
352+
elif kind == ArgKind.ARG_STAR2:
353+
# Unpack[TypedDict] is normalized early, it isn't stored as Unpack
354+
extra_items = type
355+
elif not kind.is_star() and name is not None:
356+
kwargs[name] = type
357+
if not kwargs:
358+
return Instance(dict_type.type, [dict_type.args[0], extra_items])
359+
# TODO: when PEP 728 is implemented, pass extra_items below.
360+
return TypedDictType(kwargs, required_names, set(), fallback=dict_type)
361+
286362
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
287363
# Sometimes solver may need to expand a type variable with (a copy of) itself
288364
# (usually together with other TypeVars, but it is hard to filter out TypeVarTuples).

test-data/unit/check-parameter-specification.test

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2599,3 +2599,160 @@ def run3(predicate: Callable[Concatenate[int, str, _P], None], *args: _P.args, *
25992599
# E: Argument 1 has incompatible type "*tuple[int | str, ...]"; expected "str" \
26002600
# E: Argument 1 has incompatible type "*tuple[int | str, ...]"; expected "_P.args"
26012601
[builtins fixtures/paramspec.pyi]
2602+
2603+
[case testRevealBoundParamSpecArgs]
2604+
from typing import Callable, Generic, ParamSpec
2605+
from typing_extensions import Concatenate, TypeVarTuple, Unpack
2606+
2607+
P = ParamSpec("P")
2608+
Ts = TypeVarTuple("Ts")
2609+
2610+
class Sneaky(Generic[P]):
2611+
def __init__(self, fn: Callable[P, object], *args: P.args, **kwargs: P.kwargs) -> None:
2612+
self.fn = fn
2613+
self.args = args
2614+
self.kwargs = kwargs
2615+
2616+
class SneakyPrefix(Generic[P]):
2617+
def __init__(self, fn: Callable[Concatenate[int, P], object], _: int, *args: P.args, **kwargs: P.kwargs) -> None:
2618+
self.fn = fn
2619+
self.args = args
2620+
self.kwargs = kwargs
2621+
2622+
def f1() -> int:
2623+
return 0
2624+
def f2(x: int) -> int:
2625+
return 0
2626+
def f3(x: int, /) -> int:
2627+
return 0
2628+
def f4(*, x: int) -> int:
2629+
return 0
2630+
def f5(x: int, y: int = 0) -> int:
2631+
return 0
2632+
def f6(x: int, *args: int) -> int:
2633+
return 0
2634+
def f7(x: int, *args: Unpack[Ts]) -> int:
2635+
return 0
2636+
def f8(x: int, *args: Unpack[tuple[str, ...]]) -> int:
2637+
return 0
2638+
def f9(x: int, *args: Unpack[tuple[str, int]]) -> int:
2639+
return 0
2640+
def f10(x: int=0, *args: Unpack[tuple[str, ...]]) -> int:
2641+
return 0
2642+
def f11(x: int = 0, /) -> int:
2643+
return 0
2644+
2645+
reveal_type(Sneaky(f1).args) # N: Revealed type is "tuple[()]"
2646+
reveal_type(SneakyPrefix(f1).args) # E: Missing positional argument "_" in call to "SneakyPrefix" \
2647+
# N: Revealed type is "tuple[()]" \
2648+
# E: Argument 1 to "SneakyPrefix" has incompatible type "Callable[[], int]"; expected "Callable[[int], object]"
2649+
2650+
reveal_type(Sneaky(f2, 1).args) # N: Revealed type is "tuple[()] | tuple[builtins.int]"
2651+
reveal_type(SneakyPrefix(f2, 1).args) # N: Revealed type is "tuple[()]"
2652+
2653+
reveal_type(Sneaky(f3, 1).args) # N: Revealed type is "tuple[builtins.int]"
2654+
reveal_type(SneakyPrefix(f3, 1).args) # N: Revealed type is "tuple[()]"
2655+
2656+
reveal_type(Sneaky(f4, x=1).args) # N: Revealed type is "tuple[()]"
2657+
2658+
reveal_type(Sneaky(f5, 1).args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, builtins.int]"
2659+
reveal_type(SneakyPrefix(f5, 1).args) # N: Revealed type is "tuple[()] | tuple[builtins.int]"
2660+
reveal_type(Sneaky(f5, 1, 2).args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, builtins.int]"
2661+
reveal_type(SneakyPrefix(f5, 1, 2).args) # N: Revealed type is "tuple[()] | tuple[builtins.int]"
2662+
2663+
reveal_type(Sneaky(f6, 1).args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]"
2664+
reveal_type(SneakyPrefix(f6, 1).args) # N: Revealed type is "tuple[()] | tuple[Unpack[builtins.tuple[builtins.int, ...]]]"
2665+
reveal_type(Sneaky(f6, 1, 2).args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]"
2666+
reveal_type(SneakyPrefix(f6, 1, 2).args) # N: Revealed type is "tuple[()] | tuple[Unpack[builtins.tuple[builtins.int, ...]]]"
2667+
2668+
reveal_type(Sneaky(f7, 1, 2).args) # N: Revealed type is "tuple[Literal[1]?, Literal[2]?]"
2669+
reveal_type(SneakyPrefix(f7, 1, 2).args) # N: Revealed type is "tuple[Literal[2]?]"
2670+
2671+
reveal_type(Sneaky(f8, 1, '').args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, Unpack[builtins.tuple[builtins.str, ...]]]"
2672+
reveal_type(SneakyPrefix(f8, 1, '').args) # N: Revealed type is "tuple[()] | tuple[Unpack[builtins.tuple[builtins.str, ...]]]"
2673+
2674+
reveal_type(Sneaky(f9, 1, '', 0).args) # N: Revealed type is "tuple[builtins.int, builtins.str, builtins.int]"
2675+
reveal_type(SneakyPrefix(f9, 1, '', 0).args) # N: Revealed type is "tuple[builtins.str, builtins.int]"
2676+
2677+
reveal_type(Sneaky(f10, 1, '', '').args) # N: Revealed type is "tuple[()] | tuple[builtins.int] | tuple[builtins.int, Unpack[builtins.tuple[builtins.str, ...]]]"
2678+
reveal_type(SneakyPrefix(f10, 1, '', '').args) # N: Revealed type is "tuple[()] | tuple[Unpack[builtins.tuple[builtins.str, ...]]]"
2679+
2680+
reveal_type(Sneaky(f11).args) # N: Revealed type is "tuple[()] | tuple[builtins.int]"
2681+
[builtins fixtures/paramspec.pyi]
2682+
2683+
[case testRevealBoundParamSpecGeneric]
2684+
from typing import Callable, Generic, ParamSpec, TypeVar
2685+
from typing_extensions import TypeVarTuple, Unpack
2686+
2687+
T = TypeVar("T")
2688+
P = ParamSpec("P")
2689+
Ts = TypeVarTuple("Ts")
2690+
2691+
class SplitSneaky(Generic[P]):
2692+
def __init__(self, target: Callable[P, None]) -> None:
2693+
...
2694+
2695+
def run(self, *args: P.args, **kwargs: P.kwargs) -> None:
2696+
self.args = args
2697+
self.kwargs = kwargs
2698+
2699+
def f1(x: T) -> None: ...
2700+
def f2(*xs: Unpack[Ts]) -> None: ...
2701+
def f3(fn: Callable[P, None]) -> None: ...
2702+
2703+
reveal_type(SplitSneaky(f1).args) # N: Revealed type is "builtins.tuple[builtins.object, ...]"
2704+
reveal_type(SplitSneaky(f1).kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
2705+
reveal_type(SplitSneaky(f2).args) # N: Revealed type is "builtins.tuple[builtins.object, ...]"
2706+
reveal_type(SplitSneaky(f2).kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
2707+
reveal_type(SplitSneaky(f3).args) # N: Revealed type is "builtins.tuple[builtins.object, ...]"
2708+
reveal_type(SplitSneaky(f3).kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.object]"
2709+
[builtins fixtures/paramspec.pyi]
2710+
2711+
[case testRevealBoundParamSpecKwargs]
2712+
from typing import Callable, Generic, ParamSpec
2713+
from typing_extensions import Unpack, NotRequired, TypedDict
2714+
2715+
P = ParamSpec("P")
2716+
2717+
class Sneaky(Generic[P]):
2718+
def __init__(self, fn: Callable[P, object], *args: P.args, **kwargs: P.kwargs) -> None:
2719+
self.fn = fn
2720+
self.args = args
2721+
self.kwargs = kwargs
2722+
2723+
class Opt(TypedDict):
2724+
y: int
2725+
z: NotRequired[str]
2726+
2727+
def f1() -> int:
2728+
return 0
2729+
def f2(x: int) -> int:
2730+
return 0
2731+
def f3(x: int, /) -> int:
2732+
return 0
2733+
def f4(*, x: int) -> int:
2734+
return 0
2735+
def f5(x: int, y: int = 0) -> int:
2736+
return 0
2737+
def f6(**kwargs: int) -> int:
2738+
return 0
2739+
def f7(x: int, **kwargs: str) -> int:
2740+
return 0
2741+
def f8(x: int, /, **kwargs: str) -> int:
2742+
return 0
2743+
def f9(x: int, **kwargs: Unpack[Opt]) -> int:
2744+
return 0
2745+
2746+
reveal_type(Sneaky(f1).kwargs) # N: Revealed type is "builtins.dict[builtins.str, Never]"
2747+
reveal_type(Sneaky(f2, 1).kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int})"
2748+
reveal_type(Sneaky(f3, 1).kwargs) # N: Revealed type is "builtins.dict[builtins.str, Never]"
2749+
reveal_type(Sneaky(f4, x=1).kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x': builtins.int})"
2750+
reveal_type(Sneaky(f5, 1).kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int, 'y'?: builtins.int})"
2751+
reveal_type(Sneaky(f5, 1, 2).kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int, 'y'?: builtins.int})"
2752+
reveal_type(Sneaky(f6, x=1).kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.int]"
2753+
reveal_type(Sneaky(f6, x=1, y=2).kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.int]"
2754+
reveal_type(Sneaky(f7, 1, y='').kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int})"
2755+
reveal_type(Sneaky(f8, 1, y='').kwargs) # N: Revealed type is "builtins.dict[builtins.str, builtins.str]"
2756+
reveal_type(Sneaky(f9, 1, y=0).kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int, 'y': builtins.int, 'z'?: builtins.str})"
2757+
reveal_type(Sneaky(f9, 1, y=0, z='').kwargs) # N: Revealed type is "TypedDict('builtins.dict', {'x'?: builtins.int, 'y': builtins.int, 'z'?: builtins.str})"
2758+
[builtins fixtures/paramspec.pyi]

0 commit comments

Comments
 (0)