Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework discovery timeout logic #153

Merged
merged 1 commit into from
May 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
packages=setuptools.find_packages(exclude=["tests", "tests.*"]),
install_requires=[
"aiohttp>=3.5.4, <4",
"async_timeout>=4.0.2",
"voluptuous>=0.11.5",
"importlib_metadata>=3.6; python_version<'3.10'",
"typing_extensions>=4.1.0; python_version<'3.11'",
Expand Down
17 changes: 11 additions & 6 deletions solax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,21 @@
import asyncio
import logging

from async_timeout import timeout

from solax.discovery import discover
from solax.inverter import Inverter, InverterResponse
from solax.inverter_http_client import REQUEST_TIMEOUT

_LOGGER = logging.getLogger(__name__)


REQUEST_TIMEOUT = 5
__all__ = (
"discover",
"real_time_api",
"rt_request",
"Inverter",
"InverterResponse",
"RealTimeAPI",
"REQUEST_TIMEOUT",
)


async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse:
Expand All @@ -23,8 +29,7 @@ async def rt_request(inv: Inverter, retry, t_wait=0) -> InverterResponse:
new_wait = (t_wait * 2) + 5
retry = retry - 1
try:
async with timeout(REQUEST_TIMEOUT):
return await inv.get_data()
return await inv.get_data()
except asyncio.TimeoutError:
if retry > 0:
return await rt_request(inv, retry, new_wait)
Expand Down
164 changes: 79 additions & 85 deletions solax/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import sys
from asyncio import Future, Task
from collections import defaultdict
from typing import Dict, Literal, Optional, Sequence, Set, TypedDict, Union, cast

from async_timeout import timeout
from typing import Dict, Literal, Sequence, Set, TypedDict, Union, cast

from solax.inverter import Inverter
from solax.inverter_http_client import InverterHttpClient
Expand All @@ -29,7 +27,6 @@


class DiscoveryKeywords(TypedDict, total=False):
timeout: Optional[float]
inverters: Sequence[Inverter]
return_when: Union[Literal["ALL_COMPLETED"], Literal["FIRST_COMPLETED"]]

Expand Down Expand Up @@ -72,89 +69,86 @@ async def _discovery_task(i) -> Inverter:
async def discover(
host, port, pwd="", **kwargs: Unpack[DiscoveryKeywords]
) -> Union[Inverter, Set[Inverter]]:
async with timeout(kwargs.get("timeout", 15)):
done: Set[_InverterTask] = set()
pending: Set[_InverterTask] = set()
failures = set()
requests: Dict[InverterHttpClient, Future] = defaultdict(
asyncio.get_running_loop().create_future
)

return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED)
for cls in kwargs.get("inverters", REGISTRY):
for inverter in cls.build_all_variants(host, port, pwd):
inverter.http_client = cast(
InverterHttpClient,
_DiscoveryHttpClient(
inverter, inverter.http_client, requests[inverter.http_client]
),
)

pending.add(
asyncio.create_task(_discovery_task(inverter), name=f"{inverter}")
)

if not pending:
raise DiscoveryError("No inverters to try to discover")

def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]:
for task in pending:
task.cancel()
return pending

def remove_failures_from(done: Set[_InverterTask]) -> None:
for task in set(done):
exc = task.exception()
if exc:
failures.add(exc)
done.remove(task)

# stagger HTTP request to prevent accidental Denial Of Service
async def stagger() -> None:
for http_client, future in requests.items():
future.set_result(asyncio.create_task(http_client.request()))
await asyncio.sleep(1)

staggered = asyncio.create_task(stagger())

while pending and (not done or return_when != asyncio.FIRST_COMPLETED):
try:
done, pending = await asyncio.wait(pending, return_when=return_when)
except asyncio.CancelledError:
staggered.cancel()
await asyncio.gather(
staggered, *cancel(pending), return_exceptions=True
)
raise

remove_failures_from(done)

if done and return_when == asyncio.FIRST_COMPLETED:
break

logging.debug("%d discovery tasks are still running...", len(pending))

if pending and return_when != asyncio.FIRST_COMPLETED:
pending.update(done)
done.clear()
done: Set[_InverterTask] = set()
pending: Set[_InverterTask] = set()
failures = set()
requests: Dict[InverterHttpClient, Future] = defaultdict(
asyncio.get_running_loop().create_future
)

return_when = kwargs.get("return_when", asyncio.FIRST_COMPLETED)
for cls in kwargs.get("inverters", REGISTRY):
for inverter in cls.build_all_variants(host, port, pwd):
inverter.http_client = cast(
InverterHttpClient,
_DiscoveryHttpClient(
inverter, inverter.http_client, requests[inverter.http_client]
),
)

pending.add(
asyncio.create_task(_discovery_task(inverter), name=f"{inverter}")
)

if not pending:
raise DiscoveryError("No inverters to try to discover")

def cancel(pending: Set[_InverterTask]) -> Set[_InverterTask]:
for task in pending:
task.cancel()
return pending

def remove_failures_from(done: Set[_InverterTask]) -> None:
for task in set(done):
exc = task.exception()
if exc:
failures.add(exc)
done.remove(task)

# stagger HTTP request to prevent accidental Denial Of Service
async def stagger() -> None:
for http_client, future in requests.items():
future.set_result(asyncio.create_task(http_client.request()))
await asyncio.sleep(1)

staggered = asyncio.create_task(stagger())

while pending and (not done or return_when != asyncio.FIRST_COMPLETED):
try:
done, pending = await asyncio.wait(pending, return_when=return_when)
except asyncio.CancelledError:
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)
raise

remove_failures_from(done)
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)

if done:
logging.info("Discovered inverters: %s", {task.result() for task in done})
if return_when == asyncio.FIRST_COMPLETED:
return await next(iter(done))

return {task.result() for task in done}

raise DiscoveryError(
"Unable to connect to the inverter at "
f"host={host} port={port}, or your inverter is not supported yet.\n"
"Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n"
f"Failures={str(failures)}"
)

if done and return_when == asyncio.FIRST_COMPLETED:
break

logging.debug("%d discovery tasks are still running...", len(pending))

if pending and return_when != asyncio.FIRST_COMPLETED:
pending.update(done)
done.clear()

remove_failures_from(done)
staggered.cancel()
await asyncio.gather(staggered, *cancel(pending), return_exceptions=True)

if done:
logging.info("Discovered inverters: %s", {task.result() for task in done})
if return_when == asyncio.FIRST_COMPLETED:
return await next(iter(done))

return {task.result() for task in done}

raise DiscoveryError(
"Unable to connect to the inverter at "
f"host={host} port={port}, or your inverter is not supported yet.\n"
"Please see https://github.com/squishykid/solax/wiki/DiscoveryError\n"
f"Failures={str(failures)}"
)


class DiscoveryError(Exception):
Expand Down
10 changes: 8 additions & 2 deletions solax/inverter_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
if sys.version_info >= (3, 10):
from dataclasses import KW_ONLY


REQUEST_TIMEOUT = 5.0
_CACHE: WeakValueDictionary[int, InverterHttpClient] = WeakValueDictionary()


Expand Down Expand Up @@ -107,7 +109,9 @@ async def request(self):
async def get(self):
url = self.url + "?" + self.query if self.query else self.url
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=self.headers) as req:
async with session.get(
url, headers=self.headers, timeout=REQUEST_TIMEOUT
) as req:
req.raise_for_status()
resp = await req.read()
return resp
Expand All @@ -116,7 +120,9 @@ async def post(self):
url = self.url + "?" + self.query if self.query else self.url
data = self.data.encode("utf-8") if self.data else None
async with aiohttp.ClientSession() as session:
async with session.post(url, headers=self.headers, data=data) as req:
async with session.post(
url, headers=self.headers, data=data, timeout=REQUEST_TIMEOUT
) as req:
req.raise_for_status()
resp = await req.read()
return resp
Expand Down