From ae912d382a6245e7dc7333d0beed689df59b2f53 Mon Sep 17 00:00:00 2001
From: Steely Wing <steely.wing@gmail.com>
Date: Sun, 5 Mar 2023 00:57:01 +0800
Subject: [PATCH 1/2] Improve 8/12FPS anime inference

---
 inference_video.py | 52 +++++++++++++++++++++++-----------------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index c91b805..6e1782c 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -184,7 +184,13 @@ def make_inference(I0, I1, n):
         else:
             return [*first_half, *second_half]
 
-def pad_image(img):
+def frame_to_image(frame):
+    global device
+    img = (torch
+        .from_numpy(np.transpose(frame, (2,0,1)))
+        .to(device, non_blocking=True)
+        .unsqueeze(0).float() / 255.
+    )
     if(args.fp16):
         return F.pad(img, padding).half()
     else:
@@ -205,43 +211,37 @@ def pad_image(img):
 _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
 _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
 
-I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-I1 = pad_image(I1)
-temp = None # save lastframe when processing static frame
+I1 = frame_to_image(lastframe)
+I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
 
 while True:
-    if temp is not None:
-        frame = temp
-        temp = None
-    else:
-        frame = read_buffer.get()
-    if frame is None:
-        break
     I0 = I1
-    I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-    I1 = pad_image(I1)
-    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
-    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
-    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+    I0_small = I1_small
 
     break_flag = False
-    if ssim > 0.996:
-        frame = read_buffer.get() # read a new frame
+    # find next key frame (the frame is not same as previous frame)
+    # anime normally use 1 image for 2~3 frames (8/12 FPS) 一拍二 / 一拍三
+    # so max skip frames = 2
+    next_frame = 0
+    while next_frame < 2:
+        next_frame += 1
+        frame = read_buffer.get()
         if frame is None:
             break_flag = True
             frame = lastframe
-        else:
-            temp = frame
-        I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-        I1 = pad_image(I1)
-        I1 = model.inference(I0, I1, args.scale)
+            break
+        I1 = frame_to_image(frame)
         I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-        frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+        if (ssim < 0.996):
+            break
         
+    inference_count = (next_frame * args.multi) - 1
+
     if ssim < 0.2:
+        # scene changed, just use previous
         output = []
-        for i in range(args.multi - 1):
+        for i in range(inference_count):
             output.append(I0)
         '''
         output = []
@@ -253,7 +253,7 @@ def pad_image(img):
             output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
         '''
     else:
-        output = make_inference(I0, I1, args.multi-1)
+        output = make_inference(I0, I1, inference_count)
 
     if args.montage:
         write_buffer.put(np.concatenate((lastframe, lastframe), 1))

From bb179c84cd11e2f46fc083d1a69c5bff60b63990 Mon Sep 17 00:00:00 2001
From: Steely Wing <steely.wing@gmail.com>
Date: Sun, 5 Mar 2023 12:20:39 +0800
Subject: [PATCH 2/2] Fix progress bar

---
 inference_video.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 6e1782c..f7dc07d 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -15,7 +15,6 @@
 
 def transferAudio(sourceVideo, targetVideo):
     import shutil
-    import moviepy.editor
     tempAudioFileName = "./temp/audio.mkv"
 
     # split audio from original video file and store in "temp" directory
@@ -189,7 +188,7 @@ def frame_to_image(frame):
     img = (torch
         .from_numpy(np.transpose(frame, (2,0,1)))
         .to(device, non_blocking=True)
-        .unsqueeze(0).float() / 255.
+        .unsqueeze(0) / 255.
     )
     if(args.fp16):
         return F.pad(img, padding).half()
@@ -220,7 +219,7 @@ def frame_to_image(frame):
 
     break_flag = False
     # find next key frame (the frame is not same as previous frame)
-    # anime normally use 1 image for 2~3 frames (8/12 FPS) 一拍二 / 一拍三
+    # anime normally use 1 image for 2~3 frames (12/8 FPS) 一拍二 / 一拍三
     # so max skip frames = 2
     next_frame = 0
     while next_frame < 2:
@@ -239,7 +238,7 @@ def frame_to_image(frame):
     inference_count = (next_frame * args.multi) - 1
 
     if ssim < 0.2:
-        # scene changed, just use previous
+        # scene changed, just use previous frame
         output = []
         for i in range(inference_count):
             output.append(I0)
@@ -265,7 +264,7 @@ def frame_to_image(frame):
         for mid in output:
             mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
             write_buffer.put(mid[:h, :w])
-    pbar.update(1)
+    pbar.update(next_frame)
     lastframe = frame
     if break_flag:
         break