Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion aws_embedded_metrics/environment/environment_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@
# limitations under the License.

import logging
import asyncio
import concurrent.futures
from collections.abc import Callable, Coroutine

from aws_embedded_metrics import config
from aws_embedded_metrics.environment import Environment
from aws_embedded_metrics.environment.default_environment import DefaultEnvironment
from aws_embedded_metrics.environment.lambda_environment import LambdaEnvironment
from aws_embedded_metrics.environment.local_environment import LocalEnvironment
from aws_embedded_metrics.environment.ec2_environment import EC2Environment
from typing import Optional
from typing import Optional, Any

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -73,3 +77,17 @@ async def resolve_environment() -> Environment:
log.info("No environment was detected. Using default.")
EnvironmentCache.environment = default_environment
return EnvironmentCache.environment


def resolve_environment_sync(
resolve_env_fn: Callable[[], Coroutine[Any, Any, Environment]] = resolve_environment
) -> Environment:
if EnvironmentCache.environment is not None:
return EnvironmentCache.environment
try:
asyncio.get_running_loop()
except RuntimeError:
return asyncio.run(resolve_env_fn())
else:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(asyncio.run, resolve_env_fn()).result()
10 changes: 8 additions & 2 deletions aws_embedded_metrics/logger/metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@

from datetime import datetime
from aws_embedded_metrics.environment import Environment
from aws_embedded_metrics.environment.environment_detector import resolve_environment_sync
from aws_embedded_metrics.logger.metrics_context import MetricsContext
from aws_embedded_metrics.validator import validate_namespace
from aws_embedded_metrics.utils import _await
from aws_embedded_metrics.config import get_config
from aws_embedded_metrics.storage_resolution import StorageResolution
from typing import Any, Awaitable, Callable, Dict, Tuple
Expand All @@ -34,17 +36,21 @@ def __init__(
self.context: MetricsContext = context or MetricsContext.empty()
self.flush_preserve_dimensions: bool = False

def flush_sync(self) -> None:
environment = resolve_environment_sync(lambda: _await(self.resolve_environment()))
self.__flush_with_environment(environment)

async def flush(self) -> None:
# resolve the environment and get the sink
# MOST of the time this will run synchonrously
# This only runs asynchronously if executing for the
# first time in a non-lambda environment
environment = await self.resolve_environment()
self.__flush_with_environment(environment)

def __flush_with_environment(self, environment: Environment) -> None:
self.__configure_context_for_environment(environment)
sink = environment.get_sink()

# accept and reset the context
sink.accept(self.context)
self.context = self.context.create_copy_with_context(self.flush_preserve_dimensions)

Expand Down
6 changes: 3 additions & 3 deletions aws_embedded_metrics/metric_scope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ def gen_wrapper(*args, **kwargs): # type: ignore
try:
for result in fn(*args, **kwargs):
if flush_on_yield:
asyncio.run(logger.flush())
logger.flush_sync()
yield result
finally:
asyncio.run(logger.flush())
logger.flush_sync()

return cast(F, gen_wrapper)

Expand Down Expand Up @@ -80,7 +80,7 @@ def wrapper(*args, **kwargs): # type: ignore
try:
return fn(*args, **kwargs)
finally:
asyncio.run(logger.flush())
logger.flush_sync()

return cast(F, wrapper)

Expand Down
11 changes: 11 additions & 0 deletions aws_embedded_metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@
# limitations under the License.

import time
from collections.abc import Awaitable
from datetime import datetime
from typing import TypeVar


T = TypeVar("T")


def now() -> int: return int(round(time.time() * 1000))


Expand All @@ -21,3 +28,7 @@ def convert_to_milliseconds(dt: datetime) -> int:
return 0

return int(round(dt.timestamp() * 1000))


async def _await(awaitable: Awaitable[T]) -> T:
return await awaitable
50 changes: 50 additions & 0 deletions tests/environment/test_environment_detector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from faker import Faker
import asyncio
import os
import pytest
from importlib import reload
Expand Down Expand Up @@ -87,3 +88,52 @@ async def test_resolve_environment_returns_override_lambda(before, monkeypatch):

# assert
assert isinstance(result, LambdaEnvironment)


@pytest.mark.asyncio
async def test_resolve_environment_sync_works_inside_running_event_loop(before, monkeypatch):
# arrange
monkeypatch.setenv("AWS_EMF_ENVIRONMENT", "default")
reload(config)
reload(environment_detector)
# verify we are inside a running loop
loop = asyncio.get_running_loop()
assert loop.is_running()

# act
result = environment_detector.resolve_environment_sync()

# assert
assert isinstance(result, DefaultEnvironment)


def test_resolve_environment_sync_with_async_resolve_env_fn(before):
# arrange
expected = DefaultEnvironment()

async def async_resolve():
return expected

# act
result = environment_detector.resolve_environment_sync(async_resolve)

# assert
assert result is expected


@pytest.mark.asyncio
async def test_resolve_environment_sync_with_async_resolve_env_fn_inside_running_loop(before):
# arrange
expected = DefaultEnvironment()

async def async_resolve():
return expected

loop = asyncio.get_running_loop()
assert loop.is_running()

# act
result = environment_detector.resolve_environment_sync(async_resolve)

# assert
assert result is expected
24 changes: 19 additions & 5 deletions tests/logger/test_metrics_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import aws_embedded_metrics.constants as constants
import pytest
from faker import Faker
from asyncio import Future
from importlib import reload
import os
import sys
Expand Down Expand Up @@ -509,6 +508,23 @@ async def test_can_set_timestamp(mocker):
context = get_flushed_context(sink)
assert context.meta[constants.TIMESTAMP] == utils.convert_to_milliseconds(expected_value)


def test_flush_sync_sends_context_to_sink(mocker):
# arrange
expected_key = fake.word()
expected_value = fake.word()

logger, sink, env = get_logger_and_sink(mocker)
logger.set_property(expected_key, expected_value)

# act
logger.flush_sync()

# assert
context = get_flushed_context(sink)
assert context.properties[expected_key] == expected_value


# Test helper methods


Expand All @@ -522,10 +538,8 @@ def before():
def get_logger_and_sink(mocker):
env = mocker.create_autospec(spec=Environment)

def env_provider():
result_future = Future()
result_future.set_result(env)
return result_future
async def env_provider():
return env

sink = mocker.create_autospec(spec=Sink)
env.get_sink.return_value = sink
Expand Down
42 changes: 42 additions & 0 deletions tests/metric_scope/test_metric_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ async def flush(self):
print("flush called")
InvocationTracker.record()

def flush_sync(self):
print("flush_sync called")
InvocationTracker.record()

MetricsLogger.flush = flush
MetricsLogger.flush_sync = flush_sync


@pytest.mark.asyncio
Expand Down Expand Up @@ -321,6 +326,43 @@ def my_handler():
get_config().default_flush_on_yield = original_flush_on_yield


@pytest.mark.asyncio
async def test_sync_scope_works_inside_running_event_loop(mock_logger):
# arrange
expected_result = True

@metric_scope
def my_handler():
return expected_result

# act
actual_result = my_handler()

# assert
assert expected_result == actual_result
assert InvocationTracker.invocations == 1


@pytest.mark.asyncio
async def test_sync_generator_works_inside_running_event_loop(mock_logger):
# arrange
expected_results = [1, 2, 3]

@metric_scope
def my_handler():
yield from expected_results

# act
actual_results = []
for result in my_handler():
actual_results.append(result)

# assert
assert actual_results == expected_results
# TODO in v4: change to 1 — default_flush_on_yield becomes False, flushing only once at completion
assert InvocationTracker.invocations == 4 # 3 yields + 1 final flush


# Test helpers


Expand Down