From ae912d382a6245e7dc7333d0beed689df59b2f53 Mon Sep 17 00:00:00 2001
From: Steely Wing <steely.wing@gmail.com>
Date: Sun, 5 Mar 2023 00:57:01 +0800
Subject: [PATCH 1/2] Improve 8/12FPS anime inference

---
 inference_video.py | 52 ++++++++++++++++++++++++--------------------------
 1 file changed, 26 insertions(+), 26 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index c91b805..6e1782c 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -184,7 +184,13 @@ def make_inference(I0, I1, n):
     else:
         return [*first_half, *second_half]
 
-def pad_image(img):
+def frame_to_image(frame):
+    global device
+    img = (torch
+        .from_numpy(np.transpose(frame, (2,0,1)))
+        .to(device, non_blocking=True)
+        .unsqueeze(0).float() / 255.
+    )
     if(args.fp16):
         return F.pad(img, padding).half()
     else:
@@ -205,43 +211,37 @@ def pad_image(img):
 _thread.start_new_thread(build_read_buffer, (args, read_buffer, videogen))
 _thread.start_new_thread(clear_write_buffer, (args, write_buffer))
 
-I1 = torch.from_numpy(np.transpose(lastframe, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-I1 = pad_image(I1)
-temp = None # save lastframe when processing static frame
+I1 = frame_to_image(lastframe)
+I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
 while True:
-    if temp is not None:
-        frame = temp
-        temp = None
-    else:
-        frame = read_buffer.get()
-    if frame is None:
-        break
     I0 = I1
-    I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-    I1 = pad_image(I1)
-    I0_small = F.interpolate(I0, (32, 32), mode='bilinear', align_corners=False)
-    I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
-    ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
+    I0_small = I1_small
 
     break_flag = False
-    if ssim > 0.996:
-        frame = read_buffer.get() # read a new frame
+    # find the next key frame (a frame that is not the same as the previous frame)
+    # anime normally uses 1 image for 2~3 frames (8/12 FPS), i.e. animated on twos / on threes
+    # so max skip frames = 2
+    next_frame = 0
+    while next_frame < 2:
+        next_frame += 1
+        frame = read_buffer.get()
         if frame is None:
             break_flag = True
             frame = lastframe
-        else:
-            temp = frame
-        I1 = torch.from_numpy(np.transpose(frame, (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.
-        I1 = pad_image(I1)
-        I1 = model.inference(I0, I1, args.scale)
+            break
+        I1 = frame_to_image(frame)
         I1_small = F.interpolate(I1, (32, 32), mode='bilinear', align_corners=False)
         ssim = ssim_matlab(I0_small[:, :3], I1_small[:, :3])
-        frame = (I1[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:h, :w]
+        if (ssim < 0.996):
+            break
 
+    inference_count = (next_frame * args.multi) - 1
     if ssim < 0.2:
+        # scene changed, just use previous
         output = []
-        for i in range(args.multi - 1):
+        for i in range(inference_count):
             output.append(I0)
         '''
         output = []
@@ -253,7 +253,7 @@ def pad_image(img):
             output.append(torch.from_numpy(np.transpose((cv2.addWeighted(frame[:, :, ::-1], alpha, lastframe[:, :, ::-1], beta, 0)[:, :, ::-1].copy()), (2,0,1))).to(device, non_blocking=True).unsqueeze(0).float() / 255.)
         '''
     else:
-        output = make_inference(I0, I1, args.multi-1)
+        output = make_inference(I0, I1, inference_count)
 
     if args.montage:
         write_buffer.put(np.concatenate((lastframe, lastframe), 1))

From bb179c84cd11e2f46fc083d1a69c5bff60b63990 Mon Sep 17 00:00:00 2001
From: Steely Wing <steely.wing@gmail.com>
Date: Sun, 5 Mar 2023 12:20:39 +0800
Subject: [PATCH 2/2] Fix progress bar

---
 inference_video.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/inference_video.py b/inference_video.py
index 6e1782c..f7dc07d 100644
--- a/inference_video.py
+++ b/inference_video.py
@@ -15,7 +15,6 @@
 
 def transferAudio(sourceVideo, targetVideo):
     import shutil
-    import moviepy.editor
     tempAudioFileName = "./temp/audio.mkv"
 
     # split audio from original video file and store in "temp" directory
@@ -189,7 +188,7 @@ def frame_to_image(frame):
     img = (torch
         .from_numpy(np.transpose(frame, (2,0,1)))
         .to(device, non_blocking=True)
-        .unsqueeze(0).float() / 255.
+        .unsqueeze(0) / 255.
     )
     if(args.fp16):
         return F.pad(img, padding).half()
@@ -220,7 +219,7 @@ def frame_to_image(frame):
 
     break_flag = False
     # find the next key frame (a frame that is not the same as the previous frame)
-    # anime normally uses 1 image for 2~3 frames (8/12 FPS), i.e. animated on twos / on threes
+    # anime normally uses 1 image for 2~3 frames (12/8 FPS), i.e. animated on twos / on threes
     # so max skip frames = 2
     next_frame = 0
     while next_frame < 2:
@@ -239,7 +238,7 @@ def frame_to_image(frame):
 
     inference_count = (next_frame * args.multi) - 1
     if ssim < 0.2:
-        # scene changed, just use previous
+        # scene changed, just use previous frame
         output = []
         for i in range(inference_count):
             output.append(I0)
@@ -265,7 +264,7 @@ def frame_to_image(frame):
         for mid in output:
            mid = (((mid[0] * 255.).byte().cpu().numpy().transpose(1, 2, 0)))
             write_buffer.put(mid[:h, :w])
-    pbar.update(1)
+    pbar.update(next_frame)
     lastframe = frame
     if break_flag:
         break
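
For reference, the duplicate-frame grouping that both patches build on can be restated as the sketch below. It is a minimal illustration, not part of the patch: `iter_keyframes`, `to_small`, and `read_frame` are hypothetical names introduced here, and the `model.pytorch_msssim` import path for `ssim_matlab` is assumed to match what inference_video.py already uses. Each step consumes up to two source frames (a drawing held on twos or threes), stops at the first frame whose 32x32 SSIM against the previous key frame drops below 0.996, and requests `next_frame * multi - 1` in-betweens so the output frame count stays at input_frames * multi; that is also why the second patch advances the progress bar by `next_frame` instead of 1.

# Sketch only; names and thresholds mirror the patch but are assumptions for illustration.
import numpy as np
import torch
import torch.nn.functional as F
from model.pytorch_msssim import ssim_matlab  # assumed import path, as in inference_video.py

def to_small(frame, device='cpu'):
    # HWC uint8 frame -> 1x3x32x32 float tensor in [0, 1] for a cheap similarity check
    img = torch.from_numpy(np.transpose(frame, (2, 0, 1))).unsqueeze(0).to(device) / 255.
    return F.interpolate(img, (32, 32), mode='bilinear', align_corners=False)

def iter_keyframes(read_frame, multi, max_skip=2, dup_thresh=0.996):
    # Yield (prev_frame, inference_count, key_frame): consume up to `max_skip`
    # source frames per step, stopping at the first frame that differs from the
    # previous key frame by SSIM on the downscaled copies.
    prev = read_frame()
    if prev is None:
        return
    prev_small = to_small(prev)
    while True:
        next_frame = 0
        frame = small = None
        while next_frame < max_skip:
            next_frame += 1
            frame = read_frame()
            if frame is None:
                return  # end of stream; the patch flushes lastframe here instead
            small = to_small(frame)
            if ssim_matlab(prev_small[:, :3], small[:, :3]) < dup_thresh:
                break  # found the next key frame
        # next_frame * multi - 1 in-betweens keep the output at input_frames * multi
        yield prev, next_frame * multi - 1, frame
        prev, prev_small = frame, small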