Skip to content

Commit 9be1b50

Browse files
committed
Trying to apply changes from CircuitPython
1 parent 6b2c390 commit 9be1b50

22 files changed

+1007
-130
lines changed

bleak/__init__.py

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
__author__ = """Henrik Blidh"""
88
__email__ = "[email protected]"
99

10+
from bleak.backends.circuitpython import patches
11+
patches.apply_patches()
12+
1013
import asyncio
1114
import functools
1215
import inspect
@@ -17,22 +20,25 @@
1720
from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable
1821
from types import TracebackType
1922
from typing import Any, Literal, Optional, TypedDict, Union, cast, overload
20-
21-
if sys.version_info < (3, 12):
22-
from typing_extensions import Buffer
23-
else:
24-
from collections.abc import Buffer
25-
26-
if sys.version_info < (3, 11):
27-
from async_timeout import timeout as async_timeout
28-
from typing_extensions import Never, Self, Unpack, assert_never
29-
else:
30-
from asyncio import timeout as async_timeout
31-
from typing import Never, Self, Unpack, assert_never
32-
33-
from bleak.args.bluez import BlueZScannerArgs
34-
from bleak.args.corebluetooth import CBScannerArgs, CBStartNotifyArgs
35-
from bleak.args.winrt import WinRTClientArgs
23+
from typing import TYPE_CHECKING
24+
25+
if sys.implementation.name == "cpython":
26+
if sys.version_info < (3, 12):
27+
from typing_extensions import Buffer
28+
else:
29+
from collections.abc import Buffer
30+
31+
if sys.version_info < (3, 11):
32+
from async_timeout import timeout as async_timeout
33+
from typing_extensions import Never, Self, Unpack, assert_never
34+
else:
35+
from asyncio import timeout as async_timeout
36+
from typing import Never, Self, Unpack, assert_never
37+
38+
if TYPE_CHECKING:
39+
from bleak.args.bluez import BlueZScannerArgs
40+
from bleak.args.corebluetooth import CBScannerArgs, CBStartNotifyArgs
41+
from bleak.args.winrt import WinRTClientArgs
3642
from bleak.backends.characteristic import BleakGATTCharacteristic
3743
from bleak.backends.client import BaseBleakClient, get_platform_client_backend_type
3844
from bleak.backends.descriptor import BleakGATTDescriptor
@@ -50,17 +56,18 @@
5056

5157
_logger = logging.getLogger(__name__)
5258
_logger.addHandler(logging.NullHandler())
53-
if bool(os.environ.get("BLEAK_LOGGING", False)):
54-
FORMAT = "%(asctime)-15s %(name)-8s %(threadName)s %(levelname)s: %(message)s"
55-
handler = logging.StreamHandler(sys.stderr)
56-
handler.setLevel(logging.DEBUG)
57-
handler.setFormatter(logging.Formatter(fmt=FORMAT))
58-
_logger.addHandler(handler)
59-
_logger.setLevel(logging.DEBUG)
59+
if sys.implementation.name == "cpython":
60+
if bool(os.environ.get("BLEAK_LOGGING", False)):
61+
FORMAT = "%(asctime)-15s %(name)-8s %(threadName)s %(levelname)s: %(message)s"
62+
handler = logging.StreamHandler(sys.stderr)
63+
handler.setLevel(logging.DEBUG)
64+
handler.setFormatter(logging.Formatter(fmt=FORMAT))
65+
_logger.addHandler(handler)
66+
_logger.setLevel(logging.DEBUG)
6067

6168

6269
# prevent tasks from being garbage collected
63-
_background_tasks = set[asyncio.Task[None]]()
70+
_background_tasks: set[asyncio.Task[None]] = set()
6471

6572

6673
class BleakScanner:
@@ -170,7 +177,7 @@ async def advertisement_data(
170177
171178
.. versionadded:: 0.21
172179
"""
173-
devices = asyncio.Queue[tuple[BLEDevice, AdvertisementData]]()
180+
devices: asyncio.Queue[tuple[BLEDevice, AdvertisementData]] = asyncio.Queue()
174181

175182
unregister_callback = self._backend.register_detection_callback(
176183
lambda bd, ad: devices.put_nowait((bd, ad))
@@ -181,7 +188,7 @@ async def advertisement_data(
181188
finally:
182189
unregister_callback()
183190

184-
class ExtraArgs(TypedDict, total=False):
191+
class ExtraArgs(TypedDict):
185192
"""
186193
Keyword args from :class:`~bleak.BleakScanner` that can be passed to
187194
other convenience methods.
@@ -372,7 +379,7 @@ async def find_device_by_filter(
372379
async for bd, ad in scanner.advertisement_data():
373380
if filterfunc(bd, ad):
374381
return bd
375-
assert_never(cast(Never, "advertisement_data() should never stop"))
382+
# assert_never(cast(Never, "advertisement_data() should never stop"))
376383
except asyncio.TimeoutError:
377384
return None
378385

bleak/backends/circuitpython/__init__.py

Whitespace-only changes.
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import sys
2+
3+
try:
4+
from typing import NamedTuple as _typing_namedtuple
5+
except ImportError:
6+
_typing_namedtuple = None
7+
8+
9+
def circuit_namedtuple(cls=None, *, annotations=None):
10+
if sys.implementation.name != "circuitpython":
11+
def passthrough(cls):
12+
if _typing_namedtuple is not None:
13+
return _typing_namedtuple(cls.__name__, [(k, v) for k, v in getattr(cls, "__annotations__", {}).items()])
14+
else:
15+
from collections import namedtuple as _collections_namedtuple
16+
fields = list(getattr(cls, "__annotations__", {}).keys())
17+
return _collections_namedtuple(cls.__name__, fields)
18+
return passthrough if cls is None else passthrough(cls)
19+
20+
# ---- CircuitPython ----
21+
def wrap(cls):
22+
nonlocal annotations
23+
if annotations is None:
24+
raise TypeError("You must provide annotations in CircuitPython")
25+
fields = list(annotations.keys())
26+
27+
class NT:
28+
__slots__ = fields
29+
30+
def __init__(self, **kwargs):
31+
for f in fields:
32+
if f in kwargs:
33+
setattr(self, f, kwargs.pop(f))
34+
else:
35+
raise TypeError(f"Missing field {f}")
36+
if kwargs:
37+
raise TypeError(f"Unexpected fields {list(kwargs.keys())}")
38+
39+
def __repr__(self):
40+
values = ", ".join(f"{f}={getattr(self,f)!r}" for f in fields)
41+
return f"{cls.__name__}({values})"
42+
43+
return NT
44+
45+
return wrap if cls is None else wrap(cls)
46+
47+
def circuit_advertisement_data_patch(cls=None):
48+
annotations = {
49+
"local_name": "Optional[str]",
50+
"manufacturer_data": "dict[int, bytes]",
51+
"service_data": "dict[str, bytes]",
52+
"service_uuids": "list[str]",
53+
"tx_power": "Optional[int]",
54+
"rssi": "int",
55+
"platform_data": "tuple[Any, ...]",
56+
}
57+
return circuit_namedtuple(cls, annotations=annotations)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import logging
2+
from typing import Optional, Any
3+
from typing_extensions import override
4+
5+
from _bleio import set_adapter
6+
from adafruit_ble import BLERadio, Advertisement, BLEConnection
7+
8+
from bleak import BleakScanner
9+
from bleak.backends.circuitpython.scanner import BleakScannerCircuitPython
10+
from bleak.backends.client import BaseBleakClient
11+
from bleak.backends.descriptor import BleakGATTDescriptor
12+
from bleak.backends.device import BLEDevice
13+
from bleak.exc import BleakError, BleakDeviceNotFoundError
14+
15+
logger = logging.getLogger(__name__)
16+
logger.setLevel(logging.DEBUG)
17+
18+
class BleakClientCircuitPython(BaseBleakClient):
19+
def __init__(
20+
self,
21+
address_or_ble_device: BLEDevice,
22+
services=None,
23+
**kwargs,
24+
):
25+
super().__init__(address_or_ble_device, **kwargs)
26+
_adapter = kwargs.get("adapter")
27+
if _adapter is not None:
28+
set_adapter(_adapter)
29+
30+
self._radio: Optional[BLERadio] = None
31+
self._advertisement: Optional[Advertisement] = None
32+
33+
if isinstance(address_or_ble_device, BLEDevice):
34+
self._radio, self._advertisement = address_or_ble_device.details
35+
36+
self._connection: Optional[BLEConnection] = None
37+
# self._services = None
38+
# self._is_connected = False
39+
# self._mtu = 23
40+
41+
@override
42+
async def connect(self, pair, dangerous_use_bleak_cache=False, **kwargs):
43+
logger.debug("Attempting to connect BLE device @ {}".format(self.address))
44+
45+
if self.is_connected:
46+
raise BleakError("Client is already connected")
47+
48+
if not self._advertisement.connectable:
49+
raise BleakError("Device is not connectable")
50+
51+
if pair:
52+
raise NotImplementedError("Not yet implemented")
53+
54+
timeout = kwargs.get("timeout", self._timeout)
55+
56+
if self._advertisement is None:
57+
logger.debug("Attempting to find BLE device @ {}".format(self.address))
58+
59+
device = await BleakScanner.find_device_by_address(
60+
self.address, timeout=timeout, backend=BleakScannerCircuitPython
61+
)
62+
if device:
63+
self._radio, self._advertisement = device.details
64+
else:
65+
raise BleakDeviceNotFoundError(
66+
self.address, f"Device @ {self.address} was not found"
67+
)
68+
69+
if self._radio is None:
70+
self._radio = BLERadio()
71+
72+
# TODO: disconnect_callback ?
73+
74+
logger.debug("Connecting to BLE device @ {}".format(self.address))
75+
76+
# TODO: wrap async
77+
self._connection = self._radio.connect(self._advertisement.address, timeout=timeout)
78+
logger.debug("Connected to BLE device @ {}".format(self.address))
79+
80+
logger.debug("Retrieving services from BLE device @ {}".format(self.address))
81+
82+
# TODO: get services
83+
84+
logger.debug("Services retrieved from BLE device @ {}".format(self.address))
85+
86+
87+
async def disconnect(self) -> None:
88+
"""Disconnect from the peripheral device"""
89+
logger.debug("Disconnecting from BLE device @ {}".format(self.address))
90+
if (
91+
self._radio is None
92+
or self._advertisement is None
93+
or not self.is_connected
94+
):
95+
logger.debug("Device is not connected @ {}".format(self.address))
96+
return
97+
98+
# TODO: wrap async
99+
self._connection.disconnect()
100+
logger.debug("Device disconnected @ {}".format(self.address))
101+
102+
@property
103+
@override
104+
def is_connected(self) -> bool:
105+
return self._connection is not None and self._connection.connected
106+
107+
@property
108+
@override
109+
def mtu_size(self) -> int:
110+
"""Get ATT MTU size for active connection"""
111+
# TODO: implement
112+
raise NotImplementedError()
113+
114+
@override
115+
async def pair(self, *args: Any, **kwargs: Any) -> None:
116+
raise NotImplementedError()
117+
118+
@override
119+
async def unpair(self) -> None:
120+
raise NotImplementedError()
121+
122+
@override
123+
async def read_gatt_char(self, characteristic, **kwargs):
124+
...
125+
126+
@override
127+
async def read_gatt_descriptor(self, descriptor: BleakGATTDescriptor, **kwargs):
128+
...
129+
130+
@override
131+
async def write_gatt_char(self, characteristic, data, response):
132+
...
133+
134+
@override
135+
async def write_gatt_descriptor(self, descriptor, data):
136+
...
137+
138+
@override
139+
async def start_notify(
140+
self,
141+
characteristic,
142+
callback,
143+
**kwargs,
144+
):
145+
...
146+
147+
@override
148+
async def stop_notify(self, characteristic):
149+
...
150+
151+
async def get_rssi(self) -> int:
152+
assert self._advertisement
153+
return self._advertisement.rssi
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from adafruit_ble import BLERadio
2+
3+
4+
class Manager:
5+
def __init__(self) -> None:
6+
pass
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Monkeypatch absent modules in CircuitPython"""
2+
import sys
3+
4+
5+
def apply_patches():
6+
if sys.implementation.name == "circuitpython":
7+
import builtins
8+
assert builtins
9+
import asyncio
10+
assert asyncio
11+
12+
from . import _fake_enum
13+
from . import _fake_types
14+
from . import _fake_typing
15+
from . import _inspect
16+
from . import _platform
17+
from . import _fake_abc
18+
from . import _fake_collections_abc
19+
from . import _uuid
20+
from . import _deprecation_warn
21+
from ._async_timeout import async_timeout as _async_timeout
22+
import adafruit_logging as _logging
23+
import circuitpython_functools as _functools
24+
25+
if 'builtins' not in sys.modules:
26+
sys.modules['builtins'] = builtins
27+
sys.modules["builtins"].DeprecationWarning = _deprecation_warn.DeprecationWarning
28+
sys.modules["enum"] = _fake_enum
29+
sys.modules["inspect"] = _inspect
30+
sys.modules["platform"] = _platform
31+
sys.modules["uuid"] = _uuid
32+
sys.modules["types"] = _fake_types
33+
sys.modules["typing"] = _fake_typing
34+
sys.modules["typing_extensions"] = _fake_typing
35+
sys.modules["logging"] = _logging
36+
sys.modules["abc"] = _fake_abc
37+
sys.modules["collections.abc"] = _fake_collections_abc
38+
sys.modules['functools'] = _functools
39+
sys.modules['asyncio'].timeout = _async_timeout
40+
41+
print("patched")

0 commit comments

Comments
 (0)