Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add replay logging mechanism #802

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
62 changes: 62 additions & 0 deletions openadapt/alembic/versions/c84664aeb5ae_add_replay_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""add_replay_models

Revision ID: c84664aeb5ae
Revises: bb25e889ad71
Create Date: 2024-06-25 15:05:09.110171

"""
from alembic import op
import sqlalchemy as sa

import openadapt

# revision identifiers, used by Alembic.
revision = "c84664aeb5ae"
down_revision = "bb25e889ad71"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"replay",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"timestamp",
openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
nullable=True,
),
sa.Column("strategy_name", sa.String(), nullable=True),
sa.Column("strategy_args", sa.JSON(), nullable=True),
sa.Column("git_hash", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_replay")),
)
op.create_table(
"replay_log",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("replay_id", sa.Integer(), nullable=True),
sa.Column("lineno", sa.Integer(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("git_hash", sa.String(), nullable=True),
sa.Column(
"timestamp",
openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
nullable=True,
),
sa.Column("log_level", sa.String(), nullable=True),
sa.Column("key", sa.String(), nullable=True),
sa.Column("data", sa.LargeBinary(), nullable=True),
sa.ForeignKeyConstraint(
["replay_id"], ["replay.id"], name=op.f("fk_replay_log_replay_id_replay")
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_replay_log")),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("replay_log")
op.drop_table("replay")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def update_args_inputs() -> None:
logger.info(f"kwargs=\n{pformat(kwargs)}")

self.child_conn.send({"type": "replay.starting"})
record_replay = False
record_replay = True
recording_timestamp = None
strategy_name = selected_strategy.__name__
replay_proc = multiprocessing.Process(
Expand Down
71 changes: 70 additions & 1 deletion openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import asyncio
import json
import os
import pickle
import sys
import time

from loguru import logger
Expand All @@ -15,7 +17,6 @@
import psutil
import sqlalchemy as sa


from openadapt import utils
from openadapt.config import DATABASE_LOCK_FILE_PATH, config
from openadapt.db.db import Session, get_read_only_session_maker
Expand All @@ -25,6 +26,8 @@
MemoryStat,
PerformanceStat,
Recording,
Replay,
ReplayLog,
Screenshot,
ScrubbedRecording,
WindowEvent,
Expand Down Expand Up @@ -725,6 +728,72 @@ def get_audio_info(
return audio_infos[0] if audio_infos else None


def add_replay(
session: SaSession,
strategy_name: str,
strategy_args: dict,
) -> int:
"""Add a replay to the database.

Args:
session (sa.orm.Session): The database session.
strategy_name (str): The name of the replay strategy.
strategy_args (dict): The arguments of the replay strategy.

Returns:
int: The id of the replay.
"""
git_hash = utils.get_git_hash()
timestamp = utils.get_timestamp()
replay = Replay(
timestamp=timestamp,
strategy_name=strategy_name,
strategy_args=strategy_args,
git_hash=git_hash,
)
session.add(replay)
session.commit()
session.refresh(replay)
return replay.id


def add_replay_log(*, replay_id: int, log_level: str, key: str, data: Any) -> None:
"""Add a replay log entry to the database.

Args:
replay_id (int): The id of the replay.
log_level (str): The log level of the log entry.
key (str): The key of the log entry.
data (Any): The data of the log entry.
"""
with get_new_session(read_and_write=True) as session:
pickled_data = pickle.dumps(data)

frame = sys._getframe(1)
caller_line = frame.f_lineno
caller_file = frame.f_code.co_filename

git_hash = utils.get_git_hash()
timestamp = utils.get_timestamp()

logger.info(
f"{caller_line=}, {caller_file=}, {git_hash=}, {timestamp=}, {log_level=},"
f" {key=}"
)
replay_log = ReplayLog(
replay_id=replay_id,
lineno=caller_line,
filename=caller_file,
git_hash=git_hash,
timestamp=timestamp,
log_level=log_level,
key=key,
data=pickled_data,
)
session.add(replay_log)
session.commit()


def post_process_events(session: SaSession, recording: Recording) -> None:
"""Post-process events.

Expand Down
39 changes: 38 additions & 1 deletion openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from openadapt.privacy.base import ScrubbingProvider, TextScrubbingMixin
from openadapt.privacy.providers import ScrubProvider


EMPTY_VALS = [None, "", [], (), {}]


Expand Down Expand Up @@ -182,6 +181,26 @@ def __init__(self, **kwargs: dict) -> None:
for key, value in properties.items():
setattr(self, key, value)

def to_log_dict(self) -> dict[str, Any]:
Copy link
Member

@abrichr abrichr Jul 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just use asdict()?

Edit: if the goal is to remove unnecessary properties, what do you think about overriding asdict, calling the super method, then removing the unnecessary keys?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes that makes more sense

"""Convert the action event to a log dictionary."""
return {
"name": self.name,
"timestamp": self.timestamp,
"mouse_x": self.mouse_x,
"mouse_y": self.mouse_y,
"mouse_dx": self.mouse_dx,
"mouse_dy": self.mouse_dy,
"mouse_button_name": self.mouse_button_name,
"mouse_pressed": self.mouse_pressed,
"key_name": self.key_name,
"key_char": self.key_char,
"key_vk": self.key_vk,
"canonical_key_name": self.canonical_key_name,
"canonical_key_char": self.canonical_key_char,
"canonical_key_vk": self.canonical_key_vk,
"element_state": self.element_state,
}

@property
def available_segment_descriptions(self) -> list[str]:
"""Gets the available segment descriptions."""
Expand Down Expand Up @@ -930,6 +949,24 @@ class Replay(db.Base):
strategy_name = sa.Column(sa.String)
strategy_args = sa.Column(sa.JSON)
git_hash = sa.Column(sa.String)
logs = sa.orm.relationship("ReplayLog", back_populates="replay")


class ReplayLog(db.Base):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about renaming "Replay" to "Execution" and "ReplayLog" to "ExecutionLog"?

"""Class representing a replay log in the database."""

__tablename__ = "replay_log"

id = sa.Column(sa.Integer, primary_key=True)
replay_id = sa.Column(sa.ForeignKey("replay.id"))
replay = sa.orm.relationship("Replay", back_populates="logs")
lineno = sa.Column(sa.Integer)
filename = sa.Column(sa.String)
git_hash = sa.Column(sa.String)
timestamp = sa.Column(ForceFloat)
log_level = sa.Column(sa.String)
key = sa.Column(sa.String)
data = sa.Column(sa.LargeBinary)


def copy_sa_instance(sa_instance: db.Base, **kwargs: dict) -> db.Base:
Expand Down
14 changes: 13 additions & 1 deletion openadapt/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
with redirect_stdout_stderr():
import fire

from openadapt import capture as _capture, utils
from openadapt import capture as _capture
from openadapt import utils
from openadapt.config import CAPTURE_DIR_PATH, print_config
from openadapt.db import crud
from openadapt.models import Recording
Expand Down Expand Up @@ -48,6 +49,7 @@ def replay(
Returns:
bool: True if replay was successful, None otherwise.
"""
utils.set_start_time()
utils.configure_logging(logger, LOG_LEVEL)
print_config()
posthog.capture(event="replay.started", properties={"strategy_name": strategy_name})
Expand Down Expand Up @@ -81,9 +83,17 @@ def replay(
strategy_class = strategy_class_by_name[strategy_name]
logger.info(f"{strategy_class=}")

write_session = crud.get_new_session(read_and_write=True)
replay_id = crud.add_replay(write_session, strategy_name, strategy_args=kwargs)

strategy = strategy_class(recording, **kwargs)
strategy.attach_replay_id(replay_id)
logger.info(f"{strategy=}")

if not crud.acquire_db_lock():
logger.error("Failed to acquire lock")
return

handler = None
rval = True
if capture:
Expand Down Expand Up @@ -113,6 +123,8 @@ def replay(
_capture.stop()
logger.remove(handler)

crud.release_db_lock()

return rval


Expand Down
48 changes: 48 additions & 0 deletions openadapt/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import numpy as np

from openadapt import models, playback, utils
from openadapt.db import crud

MAX_FRAME_TIMES = 1000

Expand Down Expand Up @@ -49,20 +50,46 @@ def get_next_action_event(
"""
pass

def attach_replay_id(self, replay_id: int) -> None:
"""Attach the replay ID to the strategy.

Args:
replay_id (int): The replay ID.
"""
self._replay_id = replay_id

def run(self) -> None:
"""Run the replay strategy."""
keyboard_controller = keyboard.Controller()
mouse_controller = mouse.Controller()
while True:
screenshot = models.Screenshot.take_screenshot()
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="screenshot",
data=screenshot.png_data,
)
self.screenshots.append(screenshot)
window_event = models.WindowEvent.get_active_window_event()
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="window_event",
data=window_event,
)
self.window_events.append(window_event)
try:
action_event = self.get_next_action_event(
screenshot,
window_event,
)
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="action_event",
data=action_event.to_log_dict(),
)
except StopIteration:
break
if self.action_events:
Expand All @@ -83,13 +110,25 @@ def run(self) -> None:
drop_constant=False,
)[0]
logger.debug(f"action_event=\n{pformat(action_event_dict)}")
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="action_event_dict",
data=action_event_dict,
)
self.action_events.append(action_event)
try:
playback.play_action_event(
action_event,
mouse_controller,
keyboard_controller,
)
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="playback",
data="success",
)
except Exception as exc:
logger.exception(exc)
import ipdb
Expand All @@ -106,5 +145,14 @@ def log_fps(self) -> None:
mean_dt = np.mean(dts)
fps = 1 / mean_dt
logger.info(f"{fps=:.2f}")
crud.add_replay_log(
replay_id=self._replay_id, log_level="INFO", key="fps", data=fps
)
if len(self.frame_times) > self.max_frame_times:
self.frame_times.pop(0)
crud.add_replay_log(
replay_id=self._replay_id,
log_level="INFO",
key="frame_times",
data=self.frame_times,
)
14 changes: 14 additions & 0 deletions openadapt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from loguru import logger
from PIL import Image, ImageEnhance
from posthog import Posthog
import git

from openadapt.build_utils import is_running_from_executable, redirect_stdout_stderr

Expand Down Expand Up @@ -952,5 +953,18 @@ def get_posthog_instance() -> DistinctIDPosthog:
return posthog


def get_git_hash() -> str:
"""Get the Git hash of the current commit."""
git_hash = None
try:
repo = git.Repo(search_parent_directories=True)
git_hash = repo.head.commit.hexsha
except git.InvalidGitRepositoryError:
git_hash = importlib.metadata.version("openadapt")
except Exception as exc:
logger.warning(f"{exc=}")
return git_hash


if __name__ == "__main__":
fire.Fire(get_functions(__name__))
Loading
Loading