Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add static typing for _midlSAFEARRAY #580

Merged
merged 7 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions comtypes/_safearray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""SAFEARRAY api functions, data types, and constants."""

from ctypes import *
from ctypes.wintypes import *
from ctypes import c_uint, c_ushort, c_void_p, POINTER, Structure, WinDLL
from ctypes.wintypes import DWORD, LONG, UINT

from comtypes import HRESULT, GUID

################################################################
Expand Down
11 changes: 6 additions & 5 deletions comtypes/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,14 @@

if TYPE_CHECKING:
from comtypes import hints # type: ignore

try:
from comtypes import _safearray
except (ImportError, AttributeError):
else:
try:
from comtypes import _safearray
except (ImportError, AttributeError):

class _safearray(object):
tagSAFEARRAY = None
class _safearray(object):
tagSAFEARRAY = None


LCID = DWORD
Expand Down
34 changes: 34 additions & 0 deletions comtypes/hints.pyi
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
# This stub contains...
# - symbols those what might occur recursive imports in runtime.
# - utilities for type hints.
import ctypes
import sys
from typing import (
Any,
Callable,
ClassVar,
Generic,
Iterator,
List,
NoReturn,
Sequence,
Tuple,
Type,
TypeVar,
Expand Down Expand Up @@ -37,9 +40,11 @@ else:
from typing_extensions import Self

import comtypes
from comtypes import IUnknown as IUnknown, GUID as GUID
from comtypes.automation import IDispatch as IDispatch, VARIANT as VARIANT
from comtypes.server import IClassFactory as IClassFactory
from comtypes.typeinfo import ITypeInfo as ITypeInfo
from comtypes._safearray import tagSAFEARRAY as tagSAFEARRAY

Incomplete: TypeAlias = Any
"""The type symbol is used temporarily until the COM library parsers or
Expand All @@ -51,6 +56,35 @@ Hresult: TypeAlias = int
arguments and with `HRESULT` as its return type in its COM method definition.
"""

_CT = TypeVar("_CT", bound=ctypes._CData)
_T_IUnknown = TypeVar("_T_IUnknown", bound=IUnknown)
_T_Struct = TypeVar("_T_Struct", bound=ctypes.Structure)

class LP_SAFEARRAY(ctypes._Pointer[tagSAFEARRAY], Generic[_CT]):
contents: tagSAFEARRAY
_itemtype_: ClassVar[_CT] # type: ignore
_vartype_: ClassVar[int]
_needsfree: ClassVar[bool]

@overload
@classmethod
def create(
cls: Type[LP_SAFEARRAY[ctypes._Pointer[_T_IUnknown]]],
value: Sequence[_T_IUnknown],
extra: ctypes._Pointer[GUID] = ...,
) -> LP_SAFEARRAY[ctypes._Pointer[_T_IUnknown]]: ...
@overload
@classmethod
def create(cls, value: Sequence[_CT], extra: Any = ...) -> LP_SAFEARRAY[_CT]: ...
@overload
def unpack(
self: LP_SAFEARRAY[ctypes._Pointer[_T_IUnknown]],
) -> Sequence[_T_IUnknown]: ...
@overload
def unpack(self: LP_SAFEARRAY[_T_Struct]) -> Sequence[_T_Struct]: ...
@overload
def unpack(self) -> Sequence[Any]: ...

_T_coclass = TypeVar("_T_coclass", bound=comtypes.CoClass)

class FirstComItfOf(Generic[_T_coclass]):
Expand Down
13 changes: 10 additions & 3 deletions comtypes/safearray.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
import threading
import array
from typing import TYPE_CHECKING
import comtypes
from ctypes import POINTER, Structure, byref, cast, c_long, memmove, pointer, sizeof
from comtypes import _safearray, IUnknown, com_interface_registry
from comtypes.patcher import Patch

if TYPE_CHECKING:
from typing import Type, TypeVar
from comtypes import hints # type: ignore

_CT = TypeVar("_CT", bound=comtypes._CData)

_safearray_type_cache = {}


Expand Down Expand Up @@ -49,18 +56,18 @@ def __bool__(self):

################################################################
# This is THE PUBLIC function: the gateway to the SAFEARRAY functionality.
def _midlSAFEARRAY(itemtype):
def _midlSAFEARRAY(itemtype: "Type[_CT]") -> "Type[hints.LP_SAFEARRAY[_CT]]":
"""This function mimics the 'SAFEARRAY(aType)' IDL idiom. It
returns a subtype of SAFEARRAY, instances will be built with a
typecode VT_... corresponding to the aType, which must be one of
the supported ctypes.
"""
try:
return POINTER(_safearray_type_cache[itemtype])
return POINTER(_safearray_type_cache[itemtype]) # type: ignore
except KeyError:
sa_type = _make_safearray_type(itemtype)
_safearray_type_cache[itemtype] = sa_type
return POINTER(sa_type)
return POINTER(sa_type) # type: ignore


def _make_safearray_type(itemtype):
Expand Down