diff --git a/aws_embedded_metrics/environment/environment_detector.py b/aws_embedded_metrics/environment/environment_detector.py index 9b3d320..aaa70c7 100644 --- a/aws_embedded_metrics/environment/environment_detector.py +++ b/aws_embedded_metrics/environment/environment_detector.py @@ -12,13 +12,17 @@ # limitations under the License. import logging +import asyncio +import concurrent.futures +from collections.abc import Callable, Coroutine + from aws_embedded_metrics import config from aws_embedded_metrics.environment import Environment from aws_embedded_metrics.environment.default_environment import DefaultEnvironment from aws_embedded_metrics.environment.lambda_environment import LambdaEnvironment from aws_embedded_metrics.environment.local_environment import LocalEnvironment from aws_embedded_metrics.environment.ec2_environment import EC2Environment -from typing import Optional +from typing import Optional, Any log = logging.getLogger(__name__) @@ -73,3 +77,17 @@ async def resolve_environment() -> Environment: log.info("No environment was detected. Using default.") EnvironmentCache.environment = default_environment return EnvironmentCache.environment + + +def resolve_environment_sync( + resolve_env_fn: Callable[[], Coroutine[Any, Any, Environment]] = resolve_environment +) -> Environment: + if EnvironmentCache.environment is not None: + return EnvironmentCache.environment + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(resolve_env_fn()) + else: + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + return pool.submit(asyncio.run, resolve_env_fn()).result() diff --git a/aws_embedded_metrics/logger/metrics_logger.py b/aws_embedded_metrics/logger/metrics_logger.py index ebbe469..70fbb26 100644 --- a/aws_embedded_metrics/logger/metrics_logger.py +++ b/aws_embedded_metrics/logger/metrics_logger.py @@ -13,8 +13,10 @@ from datetime import datetime from aws_embedded_metrics.environment import Environment +from aws_embedded_metrics.environment.environment_detector import resolve_environment_sync from aws_embedded_metrics.logger.metrics_context import MetricsContext from aws_embedded_metrics.validator import validate_namespace +from aws_embedded_metrics.utils import _await from aws_embedded_metrics.config import get_config from aws_embedded_metrics.storage_resolution import StorageResolution from typing import Any, Awaitable, Callable, Dict, Tuple @@ -34,17 +36,21 @@ def __init__( self.context: MetricsContext = context or MetricsContext.empty() self.flush_preserve_dimensions: bool = False + def flush_sync(self) -> None: + environment = resolve_environment_sync(lambda: _await(self.resolve_environment())) + self.__flush_with_environment(environment) + async def flush(self) -> None: # resolve the environment and get the sink # MOST of the time this will run synchonrously # This only runs asynchronously if executing for the # first time in a non-lambda environment environment = await self.resolve_environment() + self.__flush_with_environment(environment) + def __flush_with_environment(self, environment: Environment) -> None: self.__configure_context_for_environment(environment) sink = environment.get_sink() - - # accept and reset the context sink.accept(self.context) self.context = self.context.create_copy_with_context(self.flush_preserve_dimensions) diff --git a/aws_embedded_metrics/metric_scope/__init__.py b/aws_embedded_metrics/metric_scope/__init__.py index 2d597c0..158a5c8 100644 --- a/aws_embedded_metrics/metric_scope/__init__.py +++ b/aws_embedded_metrics/metric_scope/__init__.py @@ -49,10 +49,10 @@ def gen_wrapper(*args, **kwargs): # type: ignore try: for result in fn(*args, **kwargs): if flush_on_yield: - asyncio.run(logger.flush()) + logger.flush_sync() yield result finally: - asyncio.run(logger.flush()) + logger.flush_sync() return cast(F, gen_wrapper) @@ -80,7 +80,7 @@ def wrapper(*args, **kwargs): # type: ignore try: return fn(*args, **kwargs) finally: - asyncio.run(logger.flush()) + logger.flush_sync() return cast(F, wrapper) diff --git a/aws_embedded_metrics/utils.py b/aws_embedded_metrics/utils.py index 8753cd6..bb6b9e9 100644 --- a/aws_embedded_metrics/utils.py +++ b/aws_embedded_metrics/utils.py @@ -12,7 +12,14 @@ # limitations under the License. import time +from collections.abc import Awaitable from datetime import datetime +from typing import TypeVar + + +T = TypeVar("T") + + def now() -> int: return int(round(time.time() * 1000)) @@ -21,3 +28,7 @@ def convert_to_milliseconds(dt: datetime) -> int: return 0 return int(round(dt.timestamp() * 1000)) + + +async def _await(awaitable: Awaitable[T]) -> T: + return await awaitable diff --git a/tests/environment/test_environment_detector.py b/tests/environment/test_environment_detector.py index 1a2d117..05c8f7a 100644 --- a/tests/environment/test_environment_detector.py +++ b/tests/environment/test_environment_detector.py @@ -1,4 +1,5 @@ from faker import Faker +import asyncio import os import pytest from importlib import reload @@ -87,3 +88,52 @@ async def test_resolve_environment_returns_override_lambda(before, monkeypatch): # assert assert isinstance(result, LambdaEnvironment) + + +@pytest.mark.asyncio +async def test_resolve_environment_sync_works_inside_running_event_loop(before, monkeypatch): + # arrange + monkeypatch.setenv("AWS_EMF_ENVIRONMENT", "default") + reload(config) + reload(environment_detector) + # verify we are inside a running loop + loop = asyncio.get_running_loop() + assert loop.is_running() + + # act + result = environment_detector.resolve_environment_sync() + + # assert + assert isinstance(result, DefaultEnvironment) + + +def test_resolve_environment_sync_with_async_resolve_env_fn(before): + # arrange + expected = DefaultEnvironment() + + async def async_resolve(): + return expected + + # act + result = environment_detector.resolve_environment_sync(async_resolve) + + # assert + assert result is expected + + +@pytest.mark.asyncio +async def test_resolve_environment_sync_with_async_resolve_env_fn_inside_running_loop(before): + # arrange + expected = DefaultEnvironment() + + async def async_resolve(): + return expected + + loop = asyncio.get_running_loop() + assert loop.is_running() + + # act + result = environment_detector.resolve_environment_sync(async_resolve) + + # assert + assert result is expected diff --git a/tests/logger/test_metrics_logger.py b/tests/logger/test_metrics_logger.py index 08bd971..c7b7964 100644 --- a/tests/logger/test_metrics_logger.py +++ b/tests/logger/test_metrics_logger.py @@ -8,7 +8,6 @@ import aws_embedded_metrics.constants as constants import pytest from faker import Faker -from asyncio import Future from importlib import reload import os import sys @@ -509,6 +508,23 @@ async def test_can_set_timestamp(mocker): context = get_flushed_context(sink) assert context.meta[constants.TIMESTAMP] == utils.convert_to_milliseconds(expected_value) + +def test_flush_sync_sends_context_to_sink(mocker): + # arrange + expected_key = fake.word() + expected_value = fake.word() + + logger, sink, env = get_logger_and_sink(mocker) + logger.set_property(expected_key, expected_value) + + # act + logger.flush_sync() + + # assert + context = get_flushed_context(sink) + assert context.properties[expected_key] == expected_value + + # Test helper methods @@ -522,10 +538,8 @@ def before(): def get_logger_and_sink(mocker): env = mocker.create_autospec(spec=Environment) - def env_provider(): - result_future = Future() - result_future.set_result(env) - return result_future + async def env_provider(): + return env sink = mocker.create_autospec(spec=Sink) env.get_sink.return_value = sink diff --git a/tests/metric_scope/test_metric_scope.py b/tests/metric_scope/test_metric_scope.py index 6496651..2405a61 100644 --- a/tests/metric_scope/test_metric_scope.py +++ b/tests/metric_scope/test_metric_scope.py @@ -16,7 +16,12 @@ async def flush(self): print("flush called") InvocationTracker.record() + def flush_sync(self): + print("flush_sync called") + InvocationTracker.record() + MetricsLogger.flush = flush + MetricsLogger.flush_sync = flush_sync @pytest.mark.asyncio @@ -321,6 +326,43 @@ def my_handler(): get_config().default_flush_on_yield = original_flush_on_yield +@pytest.mark.asyncio +async def test_sync_scope_works_inside_running_event_loop(mock_logger): + # arrange + expected_result = True + + @metric_scope + def my_handler(): + return expected_result + + # act + actual_result = my_handler() + + # assert + assert expected_result == actual_result + assert InvocationTracker.invocations == 1 + + +@pytest.mark.asyncio +async def test_sync_generator_works_inside_running_event_loop(mock_logger): + # arrange + expected_results = [1, 2, 3] + + @metric_scope + def my_handler(): + yield from expected_results + + # act + actual_results = [] + for result in my_handler(): + actual_results.append(result) + + # assert + assert actual_results == expected_results + # TODO in v4: change to 1 — default_flush_on_yield becomes False, flushing only once at completion + assert InvocationTracker.invocations == 4 # 3 yields + 1 final flush + + # Test helpers