Skip to content

Commit 5eb000e

Browse files
authored
Merge pull request #76 from nateraw/fix-audio-alignment
Fix audio alignment and add basic tests
2 parents 90039cf + 6a949ec commit 5eb000e

File tree

4 files changed

+92
-6
lines changed

4 files changed

+92
-6
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,4 +132,5 @@ dmypy.json
132132
dreams
133133
images
134134
run.py
135-
examples
135+
examples
136+
test_outputs

stable_diffusion_videos/stable_diffusion_pipeline.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ def walk(
654654
if not resume and isinstance(num_interpolation_steps, int):
655655
num_interpolation_steps = [num_interpolation_steps] * (len(prompts) - 1)
656656

657-
if not resume and audio_filepath:
657+
if not resume:
658658
audio_start_sec = audio_start_sec or 0
659659

660660
# Save/reload prompt config
@@ -719,6 +719,9 @@ def walk(
719719
continue
720720
print(f"Resuming {save_path.name} from frame {skip}")
721721

722+
audio_offset = audio_start_sec + sum(num_interpolation_steps[:i]) / fps
723+
audio_duration = num_step / fps
724+
722725
self.generate_interpolation_clip(
723726
prompt_a,
724727
prompt_b,
@@ -736,8 +739,8 @@ def walk(
736739
skip=skip,
737740
T=get_timesteps_arr(
738741
audio_filepath,
739-
offset=audio_start_sec + (i * num_step / fps),
740-
duration=num_step / fps,
742+
offset=audio_offset,
743+
duration=audio_duration,
741744
fps=fps,
742745
margin=(1.0, 5.0),
743746
)
@@ -750,8 +753,8 @@ def walk(
750753
fps=fps,
751754
output_filepath=step_output_filepath,
752755
glob_pattern=f"*{image_file_ext}",
753-
audio_offset=audio_start_sec + (i * num_step / fps) if audio_start_sec else 0,
754-
audio_duration=num_step / fps,
756+
audio_offset=audio_offset,
757+
audio_duration=audio_duration,
755758
sr=44100,
756759
)
757760

tests/samples/choice.wav

431 KB
Binary file not shown.

tests/test_pipeline.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Tests require GPU, so they will not be running on CI (unless someone
2+
wants to figure that out for me).
3+
4+
We'll run these locally before pushing to the repo, or at the very least
5+
before making a release.
6+
"""
7+
8+
from stable_diffusion_videos import NoCheck, StableDiffusionWalkPipeline
9+
import torch
10+
from pathlib import Path
11+
from shutil import rmtree
12+
13+
import pytest
14+
15+
16+
TEST_OUTPUT_ROOT = "test_outputs"
17+
SAMPLES_DIR = Path(__file__).parent / "samples"
18+
19+
@pytest.fixture
20+
def pipeline(scope="session"):
21+
pipe = StableDiffusionWalkPipeline.from_pretrained(
22+
"CompVis/stable-diffusion-v1-4",
23+
use_auth_token=True,
24+
torch_dtype=torch.float16,
25+
revision="fp16",
26+
).to('cuda')
27+
pipe.safety_checker = NoCheck().cuda()
28+
return pipe
29+
30+
31+
@pytest.fixture
32+
def run_name(request):
33+
fn_name = request.node.name.lstrip('test_')
34+
output_path = Path(TEST_OUTPUT_ROOT) / fn_name
35+
if output_path.exists():
36+
rmtree(output_path)
37+
# We could instead yield here and rm the dir after its written.
38+
# However, I like being able to view the files locally to see if they look right.
39+
return fn_name
40+
41+
42+
def test_walk_basic(pipeline, run_name):
43+
video_path = pipeline.walk(
44+
['a cat', 'a dog', 'a horse'],
45+
seeds=[42, 1337, 2022],
46+
num_interpolation_steps=[3, 3],
47+
output_dir=TEST_OUTPUT_ROOT,
48+
name=run_name,
49+
fps=3,
50+
)
51+
assert Path(video_path).exists(), "Video file was not created"
52+
53+
54+
def test_walk_with_audio(pipeline, run_name):
55+
fps = 6
56+
audio_offsets = [2, 4, 5, 8]
57+
num_interpolation_steps = [(b - a) * fps for a, b in zip(audio_offsets, audio_offsets[1:])]
58+
video_path = pipeline.walk(
59+
['a cat', 'a dog', 'a horse', 'a cow'],
60+
seeds=[42, 1337, 4321, 1234],
61+
num_interpolation_steps=num_interpolation_steps,
62+
output_dir=TEST_OUTPUT_ROOT,
63+
name=run_name,
64+
fps=fps,
65+
audio_filepath=str(Path(SAMPLES_DIR) / 'choice.wav'),
66+
audio_start_sec=audio_offsets[0],
67+
batch_size=16,
68+
)
69+
assert Path(video_path).exists(), "Video file was not created"
70+
71+
72+
def test_walk_with_upsampler(pipeline, run_name):
73+
video_path = pipeline.walk(
74+
['a cat', 'a dog', 'a horse'],
75+
seeds=[42, 1337, 2022],
76+
num_interpolation_steps=[3, 3],
77+
output_dir=TEST_OUTPUT_ROOT,
78+
name=run_name,
79+
fps=3,
80+
upsample=True,
81+
)
82+
assert Path(video_path).exists(), "Video file was not created"

0 commit comments

Comments
 (0)