Skip to content

Commit

Permalink
Merge branch 'dev' into bug/587-bug-trains-randomly-stop
Browse files Browse the repository at this point in the history
  • Loading branch information
Lietze authored Jul 26, 2023
2 parents 2341f4c + 65f36b3 commit dcbcbbf
Show file tree
Hide file tree
Showing 21 changed files with 530 additions and 620 deletions.
131 changes: 4 additions & 127 deletions poetry.lock

Large diffs are not rendered by default.

19 changes: 16 additions & 3 deletions src/wrapper/simulation_object_updating_component.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from os import path
from typing import List
from typing import List, Optional

import sumolib
import traci
Expand Down Expand Up @@ -30,8 +30,17 @@ class SimulationObjectUpdatingComponent(Component):

_simulation_objects: List[SimulationObject]
_sumo_configuration = None
_tick: int = 0
infrastructure_provider: SumoInfrastructureProvider = None

@property
def tick(self) -> int:
"""Returns the current simulation tick
:return: The current simulation tick
"""
return self._tick

@property
def simulation_objects(self) -> List[SimulationObject]:
"""Returns a list of all objects in the simulation
Expand Down Expand Up @@ -99,7 +108,7 @@ def tracks(self) -> List[Track]:
def __init__(
self,
event_bus: EventBus = None,
sumo_configuration: str = os.getenv("SUMO_CONFIG_PATH"),
sumo_configuration: Optional[str] = os.getenv("SUMO_CONFIG_PATH"),
):
"""Creates a new SimulationObjectUpdatingComponent.
Expand Down Expand Up @@ -129,6 +138,7 @@ def add_subscriptions(self):
)

def next_tick(self, tick: int):
self._tick = tick
if tick == 1:
for signal in self.signals:
signal.set_incoming_index()
Expand All @@ -148,13 +158,16 @@ def _remove_stale_vehicles(self):
vehicles_to_remove = stored_vehicles - simulation_vehicles

for vehicle in vehicles_to_remove:
train = next(
train: Train = next(
(train for train in self.trains if train.identifier == vehicle)
)
self.infrastructure_provider.train_drove_off_track(train, train.edge)
self.event_bus.remove_train(self.tick, train.identifier)
self._simulation_objects.remove(train)

def _fetch_initial_simulation_objects(self):
assert self._sumo_configuration is not None

folder = path.dirname(self._sumo_configuration)
inputs = next(sumolib.xml.parse(self._sumo_configuration, "input"))
print(inputs)
Expand Down
36 changes: 24 additions & 12 deletions src/wrapper/simulation_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ class State(IntEnum):
_speed: float
_timetable: List[Platform]
state: State = State.DRIVING
_stop_state: bool
train_type: TrainType
reserved_tracks: List[ReservationTrack]
_station_index: int = 0
Expand Down Expand Up @@ -903,41 +904,36 @@ def __init__(
self.train_type = Train.TrainType.from_sumo_type(train_type, identifier)
self._timetable = timetable
self.reserved_tracks = []
self._stop_state = False
self._last_stop_state = False

if not from_simulator:
self._add_to_simulation(identifier, train_type, route_id)
self.train_type.max_speed = self.train_type._max_speed

def _add_to_simulation(self, identifier: str, train_type: str, route: str):
vehicle.add(identifier, routeID=route, typeID=train_type)
self.updater.event_bus.spawn_train(self.updater.tick, identifier)

def update(self, data: dict):
"""Gets called whenever a simualtion tick has happened.
:param updates: The updated values for the synchronized properties
"""
self._position = data[constants.VAR_POSITION]
edge_id = data[constants.VAR_ROAD_ID]
# self._route = data[constants.VAR_ROUTE]
self._speed = data[constants.VAR_SPEED]
self._stop_state = (data[constants.VAR_STOPSTATE] & 0b00010000) > 0

if (
not hasattr(self, "_edge")
or self._edge.identifier != edge_id
and not edge_id[:1] == ":"
):
if hasattr(self, "_edge"):
if edge_id not in list(
map(lambda obj: obj.identifier, self._edge.to_node.edges)
):
raise ValueError(
(
"A Track was skipped: Old track: "
f"{self._edge.identifier}, new track: {edge_id}"
)
)

self.updater.infrastructure_provider.train_drove_off_track(
self, self._edge
)

self._edge = next(
item for item in self.updater.edges if item.identifier == edge_id
)
Expand All @@ -951,16 +947,32 @@ def update(self, data: dict):
self, self._edge
)

if self._stop_state and not self._last_stop_state:
self.updater.event_bus.arrival_train(
self.updater.tick,
self.identifier,
self.timetable[self._station_index - 1],
)
self._last_stop_state = True

if not self._stop_state and self._last_stop_state:
self.updater.event_bus.departure_train(
self.updater.tick,
self.identifier,
self.timetable[self._station_index - 1],
)
self._last_stop_state = False

def add_subscriptions(self) -> List[int]:
"""Gets called when this object is created to allow
specification of simulator-synchronized properties.
:return: The synchronized properties (see <https://sumo.dlr.de/pydoc/traci.constants.html>)
"""
return [
constants.VAR_POSITION,
# constants.VAR_ROUTE,
constants.VAR_ROAD_ID,
constants.VAR_SPEED,
constants.VAR_STOPSTATE,
]

@staticmethod
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"tests.fixtures.fixtures_logger",
"tests.fixtures.fixtures_model",
"tests.fixtures.fixtures_spawner",
"tests.fixtures.fixtures_wrapper",
"celery.contrib.pytest",
]

Expand Down
107 changes: 0 additions & 107 deletions tests/fault_injector/conftest.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
import pytest
from traci import vehicle

from src.event_bus.event_bus import EventBus
from src.implementor.models import Run, SimulationConfiguration, Token
from src.interlocking_component.route_controller import IInterlockingDisruptor
from src.logger.logger import Logger
from src.schedule.schedule import ScheduleConfiguration
from src.spawner.spawner import (
Spawner,
SpawnerConfiguration,
SpawnerConfigurationXSchedule,
)
from src.wrapper.simulation_object_updating_component import (
SimulationObjectUpdatingComponent,
)
from src.wrapper.simulation_objects import Edge, Platform, Track, Train


@pytest.fixture
Expand Down Expand Up @@ -42,13 +35,6 @@ def run(simulation_configuration):
return Run.create(simulation_configuration=simulation_configuration.id)


@pytest.fixture
def event_bus(run):
bus = EventBus(run_id=run.id)
Logger(bus)
return bus


class MockRouteController:
"""Mock up for RouteController"""

Expand All @@ -63,99 +49,6 @@ def interlocking_disruptor():
return IInterlockingDisruptor(MockRouteController())


@pytest.fixture
def simulation_object_updater():
return SimulationObjectUpdatingComponent()


@pytest.fixture
def edge() -> Edge:
return Edge("fault injector track")


@pytest.fixture
def edge_re() -> Edge:
return Edge("fault injector track-re")


@pytest.fixture
def platform() -> Platform:
return Platform("fancy-platform", platform_id="platform-1", edge_id="fancy-edge")


# pylint: disable=protected-access
@pytest.fixture
def track(edge, edge_re):
track = Track(edge, edge_re)
edge._track = track
edge_re._track = track
return track


# pylint: enable=protected-access


@pytest.fixture
def combine_track_and_wrapper(
track: Track, simulation_object_updater: SimulationObjectUpdatingComponent
):
track.updater = simulation_object_updater
simulation_object_updater.simulation_objects.append(track)
return track, simulation_object_updater


@pytest.fixture
def combine_platform_and_wrapper(
platform: Platform, simulation_object_updater: SimulationObjectUpdatingComponent
):
platform.updater = simulation_object_updater
simulation_object_updater.simulation_objects.append(platform)
return platform, simulation_object_updater


@pytest.fixture
def combine_train_and_wrapper(
train: Train, simulation_object_updater: SimulationObjectUpdatingComponent
):
train.updater = simulation_object_updater
simulation_object_updater.simulation_objects.append(train)
return train, simulation_object_updater


# pylint: disable=invalid-name, unused-argument
# disabling invalid-name allows the names routeID and typeID, despite the fact,
# that they don't follow snake_case
@pytest.fixture
def train_add(monkeypatch):
def add_train(identifier, routeID=None, typeID=None):
assert identifier is not None
assert typeID is not None

monkeypatch.setattr(vehicle, "add", add_train)


# pylint: enable=invalid-name, unused-argument


@pytest.fixture
def max_speed(monkeypatch):
# pylint: disable-next=unused-argument
def set_max_speed(train_id: str, speed: float):
pass

monkeypatch.setattr(vehicle, "setMaxSpeed", set_max_speed)


@pytest.fixture
# pylint: disable-next=unused-argument
def train(train_add, max_speed) -> Train:
return Train(
identifier="fault injector train",
train_type="cargo",
timetable=["platform-1", "platform-2"],
)


@pytest.fixture
def schedule():
schedule_configuration = ScheduleConfiguration(
Expand Down
8 changes: 4 additions & 4 deletions tests/fault_injector/test_affected_element_does_not_exist.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ def fault(
self,
fault_configuration: FaultConfiguration,
event_bus: EventBus,
simulation_object_updater: SimulationObjectUpdatingComponent,
souc: SimulationObjectUpdatingComponent,
interlocking_disruptor: IInterlockingDisruptor,
fault_type: Fault,
):
return fault_type(
configuration=fault_configuration,
event_bus=event_bus,
simulation_object_updater=simulation_object_updater,
simulation_object_updater=souc,
interlocking_disruptor=interlocking_disruptor,
)

Expand Down Expand Up @@ -165,14 +165,14 @@ def schedule_blocked_fault(
self,
schedule_blocked_fault_configuration: ScheduleBlockedFaultConfiguration,
event_bus: EventBus,
simulation_object_updater: SimulationObjectUpdatingComponent,
souc: SimulationObjectUpdatingComponent,
interlocking_disruptor: IInterlockingDisruptor,
spawner: Spawner,
):
return ScheduleBlockedFault(
configuration=schedule_blocked_fault_configuration,
event_bus=event_bus,
simulation_object_updater=simulation_object_updater,
simulation_object_updater=souc,
interlocking_disruptor=interlocking_disruptor,
spawner=spawner,
)
Expand Down
7 changes: 5 additions & 2 deletions tests/fault_injector/test_fault.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
FaultConfiguration,
)
from src.fault_injector.fault_types.fault import Fault
from src.wrapper.simulation_object_updating_component import (
SimulationObjectUpdatingComponent,
)


class TestFault:
Expand Down Expand Up @@ -40,13 +43,13 @@ def fault(
self,
configuration,
event_bus,
simulation_object_updater,
souc: SimulationObjectUpdatingComponent,
interlocking_disruptor,
):
return self.MockSpecialFault(
configuration,
event_bus,
simulation_object_updater,
souc,
interlocking_disruptor,
)

Expand Down
Loading

0 comments on commit dcbcbbf

Please sign in to comment.