Skip to content

Audio narration #673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 4, 2024
53 changes: 53 additions & 0 deletions openadapt/alembic/versions/98c8851a5321_add_audio_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""add_audio_info

Revision ID: 98c8851a5321
Revises: d714cc86fce8
Create Date: 2024-05-29 16:56:25.832333

"""
from alembic import op
import sqlalchemy as sa

import openadapt

# revision identifiers, used by Alembic.
revision = "98c8851a5321"
down_revision = "d714cc86fce8"
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"audio_info",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"timestamp",
openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
nullable=True,
),
sa.Column("flac_data", sa.LargeBinary(), nullable=True),
sa.Column("transcribed_text", sa.String(), nullable=True),
sa.Column(
"recording_timestamp",
openadapt.models.ForceFloat(precision=10, scale=2, asdecimal=False),
nullable=True,
),
sa.Column("recording_id", sa.Integer(), nullable=True),
sa.Column("sample_rate", sa.Integer(), nullable=True),
sa.Column("words_with_timestamps", sa.Text(), nullable=True),
sa.ForeignKeyConstraint(
["recording_id"],
["recording.id"],
name=op.f("fk_audio_info_recording_id_recording"),
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_audio_info")),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("audio_info")
# ### end Alembic commands ###
31 changes: 30 additions & 1 deletion openadapt/app/dashboard/api/recordings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""API endpoints for recordings."""

import json

from fastapi import APIRouter, WebSocket
from loguru import logger

Expand Down Expand Up @@ -80,6 +82,22 @@ async def get_recording_detail(websocket: WebSocket, recording_id: int) -> None:
{"type": "num_events", "value": len(action_events)}
)

try:
# TODO: change to use recording_id once scrubbing PR is merged
audio_info = crud.get_audio_info(session, recording.timestamp)[0]
words_with_timestamps = json.loads(audio_info.words_with_timestamps)
words_with_timestamps = [
{
"word": word["word"],
"start": word["start"] + action_events[0].timestamp,
"end": word["end"] + action_events[0].timestamp,
}
for word in words_with_timestamps
]
except IndexError:
words_with_timestamps = []
word_index = 0

def convert_to_str(event_dict: dict) -> dict:
"""Convert the keys to strings."""
if "key" in event_dict:
Expand All @@ -104,7 +122,18 @@ def convert_to_str(event_dict: dict) -> dict:
width, height = 0, 0
event_dict["screenshot"] = image
event_dict["dimensions"] = {"width": width, "height": height}

words = []
# each word in words_with_timestamp is a dict of word, start, end
# we want to add the word to the event_dict if the start is
# before the event timestamp
while (
word_index < len(words_with_timestamps)
and words_with_timestamps[word_index]["start"]
< event_dict["timestamp"]
):
words.append(words_with_timestamps[word_index]["word"])
word_index += 1
event_dict["words"] = words
convert_to_str(event_dict)
await websocket.send_json({"type": "action_event", "value": event_dict})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ export const ActionEvent = ({
<TableCellWithBorder>{event.parent_id}</TableCellWithBorder>
</TableRowWithBorder>
)}
{event.words && event.words.length > 0 && (
<TableRowWithBorder>
<TableCellWithBorder>transcription</TableCellWithBorder>
<TableCellWithBorder>{event.words.join(' ')}</TableCellWithBorder>
</TableRowWithBorder>
)}
<TableRowWithBorder>
<TableCellWithBorder>children</TableCellWithBorder>
<TableCellWithBorder>{event.children?.length || 0}</TableCellWithBorder>
Expand Down
1 change: 1 addition & 0 deletions openadapt/app/dashboard/types/action-event.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ export type ActionEvent = {
mask: string | null;
dimensions?: { width: number, height: number };
children?: ActionEvent[];
words?: string[];
}
4 changes: 0 additions & 4 deletions openadapt/app/tray.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ def __init__(self) -> None:

self.app.setQuitOnLastWindowClosed(False)

# since the lock is a file, delete it when starting the app so that
# new instances can start even if the previous one crashed
crud.release_db_lock(raise_exception=False)

# currently required for pyqttoast
# TODO: remove once https://github.com/niklashenning/pyqt-toast/issues/9
# is addressed
Expand Down
1 change: 1 addition & 0 deletions openadapt/config.defaults.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
"REPLAY_STRIP_ELEMENT_STATE": true,
"RECORD_VIDEO": true,
"RECORD_AUDIO": true,
"RECORD_FULL_VIDEO": false,
"RECORD_IMAGES": false,
"LOG_MEMORY": false,
Expand Down
2 changes: 2 additions & 0 deletions openadapt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
PERFORMANCE_PLOTS_DIR_PATH = (DATA_DIR_PATH / "performance").absolute()
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"

STOP_STRS = [
"oa.stop",
Expand Down Expand Up @@ -136,6 +137,7 @@ class SegmentationAdapter(str, Enum):
RECORD_WINDOW_DATA: bool = False
RECORD_READ_ACTIVE_ELEMENT_STATE: bool = False
RECORD_VIDEO: bool
RECORD_AUDIO: bool
# if false, only write video events corresponding to screenshots
RECORD_FULL_VIDEO: bool
RECORD_IMAGES: bool
Expand Down
70 changes: 64 additions & 6 deletions openadapt/db/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@

from loguru import logger
from sqlalchemy.orm import Session as SaSession
import psutil
import sqlalchemy as sa

from openadapt import utils
from openadapt.config import DATA_DIR_PATH, config
from openadapt.config import DATABASE_LOCK_FILE_PATH, config
from openadapt.db.db import Session, get_read_only_session_maker
from openadapt.models import (
ActionEvent,
AudioInfo,
MemoryStat,
PerformanceStat,
Recording,
Expand Down Expand Up @@ -618,6 +620,56 @@ def update_video_start_time(
)


def insert_audio_info(
session: SaSession,
audio_data: bytes,
transcribed_text: str,
recording: Recording,
timestamp: float,
sample_rate: int,
word_list: list,
) -> None:
"""Create an AudioInfo entry in the database.

Args:
session (sa.orm.Session): The database session.
audio_data (bytes): The audio data.
transcribed_text (str): The transcribed text.
recording (Recording): The recording object.
timestamp (float): The timestamp of the audio.
sample_rate (int): The sample rate of the audio.
word_list (list): A list of words with timestamps.
"""
audio_info = AudioInfo(
flac_data=audio_data,
transcribed_text=transcribed_text,
recording_timestamp=recording.timestamp,
recording_id=recording.id,
timestamp=timestamp,
sample_rate=sample_rate,
words_with_timestamps=json.dumps(word_list),
)
session.add(audio_info)
session.commit()


# TODO: change to use recording_id once scrubbing PR is merged
def get_audio_info(
session: SaSession,
recording_timestamp: float,
) -> list[AudioInfo]:
"""Get the audio info for a given recording.

Args:
session (sa.orm.Session): The database session.
recording_timestamp (float): The timestamp of the recording.

Returns:
list[AudioInfo]: A list of audio info for the recording.
"""
return _get(session, AudioInfo, recording_timestamp)


def post_process_events(session: SaSession, recording: Recording) -> None:
"""Post-process events.

Expand Down Expand Up @@ -764,11 +816,17 @@ def acquire_db_lock(timeout: int = 60) -> bool:
if timeout > 0 and time.time() - start > timeout:
logger.error("Failed to acquire database lock.")
return False
if os.path.exists(DATA_DIR_PATH / "database.lock"):
logger.info("Database is locked. Waiting...")
time.sleep(1)
if os.path.exists(DATABASE_LOCK_FILE_PATH):
with open(DATABASE_LOCK_FILE_PATH, "r") as lock_file:
lock_info = json.load(lock_file)
# check if the process is still running
if psutil.pid_exists(lock_info["pid"]):
logger.info("Database is locked. Waiting...")
time.sleep(1)
else:
release_db_lock(raise_exception=False)
else:
with open(DATA_DIR_PATH / "database.lock", "w") as lock_file:
with open(DATABASE_LOCK_FILE_PATH, "w") as lock_file:
lock_file.write(json.dumps({"pid": os.getpid(), "time": time.time()}))
logger.info("Database lock acquired.")
break
Expand All @@ -778,7 +836,7 @@ def acquire_db_lock(timeout: int = 60) -> bool:
def release_db_lock(raise_exception: bool = True) -> None:
"""Release the database lock."""
try:
os.remove(DATA_DIR_PATH / "database.lock")
os.remove(DATABASE_LOCK_FILE_PATH)
except Exception as e:
if raise_exception:
logger.error("Failed to release database lock.")
Expand Down
18 changes: 18 additions & 0 deletions openadapt/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class Recording(db.Base):
"ScrubbedRecording",
back_populates="recording",
)
audio_info = sa.orm.relationship("AudioInfo", back_populates="recording")

_processed_action_events = None

Expand Down Expand Up @@ -723,6 +724,23 @@ def convert_png_to_binary(self, image: Image.Image) -> bytes:
return buffer.getvalue()


class AudioInfo(db.Base):
"""Class representing the audio from a recording in the database."""

__tablename__ = "audio_info"

id = sa.Column(sa.Integer, primary_key=True)
timestamp = sa.Column(ForceFloat)
flac_data = sa.Column(sa.LargeBinary)
transcribed_text = sa.Column(sa.String)
recording_timestamp = sa.Column(ForceFloat)
recording_id = sa.Column(sa.ForeignKey("recording.id"))
sample_rate = sa.Column(sa.Integer)
words_with_timestamps = sa.Column(sa.Text)

recording = sa.orm.relationship("Recording", back_populates="audio_info")


class PerformanceStat(db.Base):
"""Class representing a performance statistic in the database."""

Expand Down
Loading
Loading