Skip to content

Commit 1e11906

Browse files
KIRA009 and abrichr authored
feat: Audio narration (#673)
* feat: Add audio narration feature while recording
* feat: Remove implicit scrubbing in display_event function and recursively convert required properties to str
* feat: Add transcribed text to dashboard visualisation
* feat: Use recording id as foreign key, and add interrupt signal handler in audio recording process
* feat: Check if the lock is stale when acquiring locks
* refactor: Convert database lock path to a constant in config file

---------

Co-authored-by: Richard Abrich <[email protected]>
1 parent 8b4d9ef commit 1e11906

File tree

14 files changed

+514
-117
lines changed

14 files changed

+514
-117
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""add_audio_info
2+
3+
Revision ID: 98c8851a5321
4+
Revises: d714cc86fce8
5+
Create Date: 2024-05-29 16:56:25.832333
6+
7+
"""
8+
from alembic import op
9+
import sqlalchemy as sa
10+
11+
import openadapt
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "98c8851a5321"
15+
down_revision = "d714cc86fce8"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade() -> None:
    """Create the ``audio_info`` table.

    The table gets an integer primary key ``id`` and a ``recording_id``
    column constrained as a foreign key to ``recording.id``. Every column
    other than ``id`` is nullable.
    """

    def force_float():
        # Build a fresh type instance per column, exactly as the generated
        # migration did for each of the two timestamp columns.
        return openadapt.models.ForceFloat(
            precision=10, scale=2, asdecimal=False
        )

    # ### commands auto generated by Alembic - please adjust! ###
    columns = [
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("timestamp", force_float(), nullable=True),
        sa.Column("flac_data", sa.LargeBinary(), nullable=True),
        sa.Column("transcribed_text", sa.String(), nullable=True),
        sa.Column("recording_timestamp", force_float(), nullable=True),
        sa.Column("recording_id", sa.Integer(), nullable=True),
        sa.Column("sample_rate", sa.Integer(), nullable=True),
        sa.Column("words_with_timestamps", sa.Text(), nullable=True),
    ]
    constraints = [
        sa.ForeignKeyConstraint(
            ["recording_id"],
            ["recording.id"],
            name=op.f("fk_audio_info_recording_id_recording"),
        ),
        sa.PrimaryKeyConstraint("id", name=op.f("pk_audio_info")),
    ]
    op.create_table("audio_info", *columns, *constraints)
    # ### end Alembic commands ###
48+
49+
50+
def downgrade() -> None:
    """Revert this revision by dropping the ``audio_info`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "audio_info"
    op.drop_table(table_name)
    # ### end Alembic commands ###

openadapt/app/dashboard/api/recordings.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""API endpoints for recordings."""
22

3+
import json
4+
35
from fastapi import APIRouter, WebSocket
46
from loguru import logger
57

@@ -80,6 +82,22 @@ async def get_recording_detail(websocket: WebSocket, recording_id: int) -> None:
8082
{"type": "num_events", "value": len(action_events)}
8183
)
8284

85+
try:
86+
# TODO: change to use recording_id once scrubbing PR is merged
87+
audio_info = crud.get_audio_info(session, recording.timestamp)[0]
88+
words_with_timestamps = json.loads(audio_info.words_with_timestamps)
89+
words_with_timestamps = [
90+
{
91+
"word": word["word"],
92+
"start": word["start"] + action_events[0].timestamp,
93+
"end": word["end"] + action_events[0].timestamp,
94+
}
95+
for word in words_with_timestamps
96+
]
97+
except IndexError:
98+
words_with_timestamps = []
99+
word_index = 0
100+
83101
def convert_to_str(event_dict: dict) -> dict:
84102
"""Convert the keys to strings."""
85103
if "key" in event_dict:
@@ -104,7 +122,18 @@ def convert_to_str(event_dict: dict) -> dict:
104122
width, height = 0, 0
105123
event_dict["screenshot"] = image
106124
event_dict["dimensions"] = {"width": width, "height": height}
107-
125+
words = []
126+
# each word in words_with_timestamps is a dict of word, start, end
127+
# we want to add the word to the event_dict if the start is
128+
# before the event timestamp
129+
while (
130+
word_index < len(words_with_timestamps)
131+
and words_with_timestamps[word_index]["start"]
132+
< event_dict["timestamp"]
133+
):
134+
words.append(words_with_timestamps[word_index]["word"])
135+
word_index += 1
136+
event_dict["words"] = words
108137
convert_to_str(event_dict)
109138
await websocket.send_json({"type": "action_event", "value": event_dict})
110139

openadapt/app/dashboard/components/ActionEvent/ActionEvent.tsx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ export const ActionEvent = ({
122122
<TableCellWithBorder>{event.parent_id}</TableCellWithBorder>
123123
</TableRowWithBorder>
124124
)}
125+
{event.words && event.words.length > 0 && (
126+
<TableRowWithBorder>
127+
<TableCellWithBorder>transcription</TableCellWithBorder>
128+
<TableCellWithBorder>{event.words.join(' ')}</TableCellWithBorder>
129+
</TableRowWithBorder>
130+
)}
125131
<TableRowWithBorder>
126132
<TableCellWithBorder>children</TableCellWithBorder>
127133
<TableCellWithBorder>{event.children?.length || 0}</TableCellWithBorder>

openadapt/app/dashboard/types/action-event.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,5 @@ export type ActionEvent = {
2626
mask: string | null;
2727
dimensions?: { width: number, height: number };
2828
children?: ActionEvent[];
29+
words?: string[];
2930
}

openadapt/app/tray.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,6 @@ def __init__(self) -> None:
7676

7777
self.app.setQuitOnLastWindowClosed(False)
7878

79-
# since the lock is a file, delete it when starting the app so that
80-
# new instances can start even if the previous one crashed
81-
crud.release_db_lock(raise_exception=False)
82-
8379
# currently required for pyqttoast
8480
# TODO: remove once https://github.com/niklashenning/pyqt-toast/issues/9
8581
# is addressed

openadapt/config.defaults.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"RECORD_READ_ACTIVE_ELEMENT_STATE": false,
2020
"REPLAY_STRIP_ELEMENT_STATE": true,
2121
"RECORD_VIDEO": true,
22+
"RECORD_AUDIO": true,
2223
"RECORD_FULL_VIDEO": false,
2324
"RECORD_IMAGES": false,
2425
"LOG_MEMORY": false,

openadapt/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
PERFORMANCE_PLOTS_DIR_PATH = (DATA_DIR_PATH / "performance").absolute()
3030
CAPTURE_DIR_PATH = (DATA_DIR_PATH / "captures").absolute()
3131
VIDEO_DIR_PATH = DATA_DIR_PATH / "videos"
32+
DATABASE_LOCK_FILE_PATH = DATA_DIR_PATH / "openadapt.db.lock"
3233

3334
STOP_STRS = [
3435
"oa.stop",
@@ -136,6 +137,7 @@ class SegmentationAdapter(str, Enum):
136137
RECORD_WINDOW_DATA: bool = False
137138
RECORD_READ_ACTIVE_ELEMENT_STATE: bool = False
138139
RECORD_VIDEO: bool
140+
RECORD_AUDIO: bool
139141
# if false, only write video events corresponding to screenshots
140142
RECORD_FULL_VIDEO: bool
141143
RECORD_IMAGES: bool

openadapt/db/crud.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@
1111

1212
from loguru import logger
1313
from sqlalchemy.orm import Session as SaSession
14+
import psutil
1415
import sqlalchemy as sa
1516

1617
from openadapt import utils
17-
from openadapt.config import DATA_DIR_PATH, config
18+
from openadapt.config import DATABASE_LOCK_FILE_PATH, config
1819
from openadapt.db.db import Session, get_read_only_session_maker
1920
from openadapt.models import (
2021
ActionEvent,
22+
AudioInfo,
2123
MemoryStat,
2224
PerformanceStat,
2325
Recording,
@@ -618,6 +620,56 @@ def update_video_start_time(
618620
)
619621

620622

623+
def insert_audio_info(
    session: SaSession,
    audio_data: bytes,
    transcribed_text: str,
    recording: Recording,
    timestamp: float,
    sample_rate: int,
    word_list: list,
) -> None:
    """Store the audio captured for a recording as a new AudioInfo row.

    The row is added to ``session`` and committed immediately.

    Args:
        session (sa.orm.Session): The database session.
        audio_data (bytes): The audio data; saved in the ``flac_data`` column.
        transcribed_text (str): The transcribed text.
        recording (Recording): The recording object; its ``timestamp`` and
            ``id`` are copied onto the new row.
        timestamp (float): The timestamp of the audio.
        sample_rate (int): The sample rate of the audio.
        word_list (list): A list of words with timestamps; serialized to a
            JSON string before being stored.
    """
    row_fields = {
        "flac_data": audio_data,
        "transcribed_text": transcribed_text,
        "recording_timestamp": recording.timestamp,
        "recording_id": recording.id,
        "timestamp": timestamp,
        "sample_rate": sample_rate,
        # the column is plain text, so the word list is kept as JSON
        "words_with_timestamps": json.dumps(word_list),
    }
    session.add(AudioInfo(**row_fields))
    session.commit()
654+
655+
656+
# TODO: change to use recording_id once scrubbing PR is merged
657+
def get_audio_info(
    session: SaSession,
    recording_timestamp: float,
) -> list[AudioInfo]:
    """Look up the audio info stored for a recording.

    Delegates the query to ``_get`` for the ``AudioInfo`` table.

    Args:
        session (sa.orm.Session): The database session.
        recording_timestamp (float): The timestamp of the recording.

    Returns:
        list[AudioInfo]: A list of audio info for the recording.
    """
    audio_infos = _get(session, AudioInfo, recording_timestamp)
    return audio_infos
671+
672+
621673
def post_process_events(session: SaSession, recording: Recording) -> None:
622674
"""Post-process events.
623675
@@ -764,11 +816,17 @@ def acquire_db_lock(timeout: int = 60) -> bool:
764816
if timeout > 0 and time.time() - start > timeout:
765817
logger.error("Failed to acquire database lock.")
766818
return False
767-
if os.path.exists(DATA_DIR_PATH / "database.lock"):
768-
logger.info("Database is locked. Waiting...")
769-
time.sleep(1)
819+
if os.path.exists(DATABASE_LOCK_FILE_PATH):
820+
with open(DATABASE_LOCK_FILE_PATH, "r") as lock_file:
821+
lock_info = json.load(lock_file)
822+
# check if the process is still running
823+
if psutil.pid_exists(lock_info["pid"]):
824+
logger.info("Database is locked. Waiting...")
825+
time.sleep(1)
826+
else:
827+
release_db_lock(raise_exception=False)
770828
else:
771-
with open(DATA_DIR_PATH / "database.lock", "w") as lock_file:
829+
with open(DATABASE_LOCK_FILE_PATH, "w") as lock_file:
772830
lock_file.write(json.dumps({"pid": os.getpid(), "time": time.time()}))
773831
logger.info("Database lock acquired.")
774832
break
@@ -778,7 +836,7 @@ def acquire_db_lock(timeout: int = 60) -> bool:
778836
def release_db_lock(raise_exception: bool = True) -> None:
779837
"""Release the database lock."""
780838
try:
781-
os.remove(DATA_DIR_PATH / "database.lock")
839+
os.remove(DATABASE_LOCK_FILE_PATH)
782840
except Exception as e:
783841
if raise_exception:
784842
logger.error("Failed to release database lock.")

openadapt/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class Recording(db.Base):
8181
"ScrubbedRecording",
8282
back_populates="recording",
8383
)
84+
audio_info = sa.orm.relationship("AudioInfo", back_populates="recording")
8485

8586
_processed_action_events = None
8687

@@ -723,6 +724,23 @@ def convert_png_to_binary(self, image: Image.Image) -> bytes:
723724
return buffer.getvalue()
724725

725726

727+
class AudioInfo(db.Base):
    """Class representing the audio from a recording in the database.

    Each row holds the audio captured alongside a recording together with
    its transcription, and is linked back to the owning ``Recording``.
    """

    __tablename__ = "audio_info"

    # surrogate primary key
    id = sa.Column(sa.Integer, primary_key=True)
    # timestamp of the audio
    timestamp = sa.Column(ForceFloat)
    # raw audio bytes (presumably FLAC-encoded, going by the column name)
    flac_data = sa.Column(sa.LargeBinary)
    # full transcription of the audio
    transcribed_text = sa.Column(sa.String)
    # timestamp of the recording this audio belongs to
    recording_timestamp = sa.Column(ForceFloat)
    # foreign key to the owning recording row
    recording_id = sa.Column(sa.ForeignKey("recording.id"))
    # sample rate of the audio
    sample_rate = sa.Column(sa.Integer)
    # JSON-encoded list of words with their timestamps
    words_with_timestamps = sa.Column(sa.Text)

    # counterpart of Recording.audio_info
    recording = sa.orm.relationship("Recording", back_populates="audio_info")
742+
743+
726744
class PerformanceStat(db.Base):
727745
"""Class representing a performance statistic in the database."""
728746

0 commit comments

Comments
 (0)