From 24cb3bbc57962c64e5acf77baae038d93982cfd8 Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Mon, 4 Mar 2024 23:09:29 +0000 Subject: [PATCH 1/4] support video frames. --- tinychat/serve/gradio_web_server.py | 46 +++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 6 deletions(-) diff --git a/tinychat/serve/gradio_web_server.py b/tinychat/serve/gradio_web_server.py index c5f5dab..29358e0 100644 --- a/tinychat/serve/gradio_web_server.py +++ b/tinychat/serve/gradio_web_server.py @@ -199,18 +199,50 @@ def clear_after_click_example_3_image_icl(imagebox, imagebox_2, imagebox_3, text def add_images( - state, imagebox, imagebox_2, imagebox_3, image_process_mode, request: gr.Request + state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode, request: gr.Request ): if state.image_loaded: # return (state,) + (None,) * IMAGE_BOX_NUM return state + + def extract_frames(video_path): + import cv2 + from PIL import Image + vidcap = cv2.VideoCapture(video_path) + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps + + frame_interval = frame_count // 8 + print(duration, frame_count, frame_interval) + + frame_interval = 10 + + def get_frame(stamp): + frame_id = int(fps * stamp) + vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) + ret, frame = vidcap.read() + assert ret, "videocap.read fails!" + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + print(f"loading {stamp} success") + return im_pil + + # return [get_frame(0), get_frame(stamp1), get_frame(stamp2)] + return [get_frame(0), get_frame(frame_interval * 1), ] + + if videobox is not None: + frames = None + frames = extract_frames(videobox) + # add frames as regular images + logger.info(f"Got videobox: {videobox}.") + logger.info(f"add_image. ip: {request.client.host}.") im_count = 0 - for image in [imagebox, imagebox_2, imagebox_3]: + image_list = [imagebox, imagebox_2, imagebox_3, *frames] + for image in image_list: if image is not None: im_count += 1 - for image in [imagebox, imagebox_2, imagebox_3]: - if image is not None: if args.auto_pad_image_token or im_count == 1: text = (AUTO_FILL_IM_TOKEN_HOLDER, image, image_process_mode) else: @@ -222,6 +254,7 @@ def add_images( # state.append_message(state.roles[0], text) # state.append_message(state.roles[1], None) # state.skip_next = False + logger.info(f"im_count {im_count}. ip: {request.client.host}.") state.image_loaded = True # return (state,) + (None,) * IMAGE_BOX_NUM return state @@ -564,6 +597,7 @@ def build_demo(embed_mode): imagebox = gr.Image(type="pil") imagebox_2 = gr.Image(type="pil") imagebox_3 = gr.Image(type="pil") + videobox = gr.Video(label="1 video = 8 frames") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", @@ -841,7 +875,7 @@ def build_demo(embed_mode): clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False ).then( add_images, - [state, imagebox, imagebox_2, imagebox_3, image_process_mode], + [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode], [state], queue=False, ).then( @@ -863,7 +897,7 @@ def build_demo(embed_mode): clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False ).then( add_images, - [state, imagebox, imagebox_2, imagebox_3, image_process_mode], + [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode], [state], queue=False, ).then( From d1175c4daf1e0ea8059984a6aeda31d5a70a0d0a Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Tue, 5 Mar 2024 03:08:32 +0000 Subject: [PATCH 2/4] fix prompt feeding --- tinychat/serve/gradio_web_server.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tinychat/serve/gradio_web_server.py b/tinychat/serve/gradio_web_server.py index 29358e0..e827b75 100644 --- a/tinychat/serve/gradio_web_server.py +++ b/tinychat/serve/gradio_web_server.py @@ -213,9 +213,8 @@ def extract_frames(video_path): frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = frame_count / fps - frame_interval = frame_count // 8 + frame_interval = frame_count // 16 print(duration, frame_count, frame_interval) - frame_interval = 10 def get_frame(stamp): @@ -231,18 +230,23 @@ def get_frame(stamp): # return [get_frame(0), get_frame(stamp1), get_frame(stamp2)] return [get_frame(0), get_frame(frame_interval * 1), ] + frames = [None, ] if videobox is not None: - frames = None frames = extract_frames(videobox) # add frames as regular images logger.info(f"Got videobox: {videobox}.") logger.info(f"add_image. ip: {request.client.host}.") - im_count = 0 image_list = [imagebox, imagebox_2, imagebox_3, *frames] + logger.info(f"image_list: {image_list}") + + im_count = 0 for image in image_list: if image is not None: im_count += 1 + + for image in image_list: + if image is not None: if args.auto_pad_image_token or im_count == 1: text = (AUTO_FILL_IM_TOKEN_HOLDER, image, image_process_mode) else: From 1efc17fe91cac88c3d881bad2727b2525d0a1069 Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Thu, 7 Mar 2024 19:49:19 +0000 Subject: [PATCH 3/4] update --- .gitignore | 3 +++ tinychat/serve/gradio_web_server.py | 38 ++++++++++++++++++++--------- tinychat/utils/constants.py | 2 +- 3 files changed, 31 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 6417a78..ee1c55b 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,6 @@ cython_debug/ **/*.pyc *.json __pycache__ +controller.log.* +tinychat/serve/gradio_web_server.log.* +tinychat/serve/model_worker_* diff --git a/tinychat/serve/gradio_web_server.py b/tinychat/serve/gradio_web_server.py index e827b75..383be89 100644 --- a/tinychat/serve/gradio_web_server.py +++ b/tinychat/serve/gradio_web_server.py @@ -214,21 +214,37 @@ def extract_frames(video_path): duration = frame_count / fps frame_interval = frame_count // 16 - print(duration, frame_count, frame_interval) + print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) frame_interval = 10 - def get_frame(stamp): - frame_id = int(fps * stamp) - vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) - ret, frame = vidcap.read() - assert ret, "videocap.read fails!" - img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - im_pil = Image.fromarray(img) - print(f"loading {stamp} success") - return im_pil + def get_frame(max_frames=6): + # frame_id = int(fps * stamp) + # vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) + # ret, frame = vidcap.read() + images = [] + count = 0 + success = True + while success: + success, frame = vidcap.read() + if count % frame_interval: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + if len(images) == max_frames: + return images + + count += 1 + # assert ret, "videocap.read fails!" + # img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # im_pil = Image.fromarray(img) + # print(f"loading {stamp} success") + return images # return [get_frame(0), get_frame(stamp1), get_frame(stamp2)] - return [get_frame(0), get_frame(frame_interval * 1), ] + # img = get_frame(0) + # img1 = get_frame(frame_interval * 1) + # return [img, img1, img, img1, img, img1,] + return get_frame(6) frames = [None, ] if videobox is not None: diff --git a/tinychat/utils/constants.py b/tinychat/utils/constants.py index 6920e88..350b66b 100644 --- a/tinychat/utils/constants.py +++ b/tinychat/utils/constants.py @@ -3,7 +3,7 @@ def init(): global max_seq_len, max_batch_size, llama_multiple_of, mem_efficient_load - max_seq_len = 2048 + max_seq_len = 5120 max_batch_size = 1 llama_multiple_of = 256 mem_efficient_load = False # Whether to load the checkpoint in a layer-wise manner. Activate this if you are facing OOM issues on edge devices (e.g., Jetson Orin). From a5bd6f653cc4ee942be570d5e0241e19f2c7a264 Mon Sep 17 00:00:00 2001 From: Ligeng Zhu Date: Thu, 7 Mar 2024 20:10:04 +0000 Subject: [PATCH 4/4] update --- tinychat/serve/gradio_web_server.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tinychat/serve/gradio_web_server.py b/tinychat/serve/gradio_web_server.py index 383be89..e7230d5 100644 --- a/tinychat/serve/gradio_web_server.py +++ b/tinychat/serve/gradio_web_server.py @@ -213,11 +213,11 @@ def extract_frames(video_path): frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = frame_count / fps - frame_interval = frame_count // 16 + frame_interval = frame_count // 10 print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) - frame_interval = 10 + # frame_interval = 10 - def get_frame(max_frames=6): + def get_frame(max_frames): # frame_id = int(fps * stamp) # vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) # ret, frame = vidcap.read() @@ -244,7 +244,7 @@ def get_frame(max_frames=6): # img = get_frame(0) # img1 = get_frame(frame_interval * 1) # return [img, img1, img, img1, img, img1,] - return get_frame(6) + return get_frame(8) frames = [None, ] if videobox is not None: