diff --git a/.gitignore b/.gitignore index 6417a78..ee1c55b 100644 --- a/.gitignore +++ b/.gitignore @@ -171,3 +171,6 @@ cython_debug/ **/*.pyc *.json __pycache__ +controller.log.* +tinychat/serve/gradio_web_server.log.* +tinychat/serve/model_worker_* diff --git a/tinychat/serve/gradio_web_server.py b/tinychat/serve/gradio_web_server.py index c5f5dab..e7230d5 100644 --- a/tinychat/serve/gradio_web_server.py +++ b/tinychat/serve/gradio_web_server.py @@ -199,17 +199,69 @@ def clear_after_click_example_3_image_icl(imagebox, imagebox_2, imagebox_3, text def add_images( - state, imagebox, imagebox_2, imagebox_3, image_process_mode, request: gr.Request + state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode, request: gr.Request ): if state.image_loaded: # return (state,) + (None,) * IMAGE_BOX_NUM return state + + def extract_frames(video_path): + import cv2 + from PIL import Image + vidcap = cv2.VideoCapture(video_path) + fps = vidcap.get(cv2.CAP_PROP_FPS) + frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT)) + duration = frame_count / fps + + frame_interval = frame_count // 10 + print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval) + # frame_interval = 10 + + def get_frame(max_frames): + # frame_id = int(fps * stamp) + # vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_id) + # ret, frame = vidcap.read() + images = [] + count = 0 + success = True + while success: + success, frame = vidcap.read() + if count % frame_interval: + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + im_pil = Image.fromarray(img) + images.append(im_pil) + if len(images) == max_frames: + return images + + count += 1 + # assert ret, "videocap.read fails!" + # img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + # im_pil = Image.fromarray(img) + # print(f"loading {stamp} success") + return images + + # return [get_frame(0), get_frame(stamp1), get_frame(stamp2)] + # img = get_frame(0) + # img1 = get_frame(frame_interval * 1) + # return [img, img1, img, img1, img, img1,] + return get_frame(8) + + frames = [None, ] + if videobox is not None: + frames = extract_frames(videobox) + # add frames as regular images + logger.info(f"Got videobox: {videobox}.") + logger.info(f"add_image. ip: {request.client.host}.") + image_list = [imagebox, imagebox_2, imagebox_3, *frames] + logger.info(f"image_list: {image_list}") + im_count = 0 - for image in [imagebox, imagebox_2, imagebox_3]: + for image in image_list: if image is not None: im_count += 1 - for image in [imagebox, imagebox_2, imagebox_3]: + + for image in image_list: if image is not None: if args.auto_pad_image_token or im_count == 1: text = (AUTO_FILL_IM_TOKEN_HOLDER, image, image_process_mode) @@ -222,6 +274,7 @@ def add_images( # state.append_message(state.roles[0], text) # state.append_message(state.roles[1], None) # state.skip_next = False + logger.info(f"im_count {im_count}. ip: {request.client.host}.") state.image_loaded = True # return (state,) + (None,) * IMAGE_BOX_NUM return state @@ -564,6 +617,7 @@ def build_demo(embed_mode): imagebox = gr.Image(type="pil") imagebox_2 = gr.Image(type="pil") imagebox_3 = gr.Image(type="pil") + videobox = gr.Video(label="1 video = 8 frames") image_process_mode = gr.Radio( ["Crop", "Resize", "Pad", "Default"], value="Default", @@ -841,7 +895,7 @@ def build_demo(embed_mode): clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False ).then( add_images, - [state, imagebox, imagebox_2, imagebox_3, image_process_mode], + [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode], [state], queue=False, ).then( @@ -863,7 +917,7 @@ def build_demo(embed_mode): clear_text_history, [state, prompt_style_btn], [state, chatbot], queue=False ).then( add_images, - [state, imagebox, imagebox_2, imagebox_3, image_process_mode], + [state, imagebox, imagebox_2, imagebox_3, videobox, image_process_mode], [state], queue=False, ).then( diff --git a/tinychat/utils/constants.py b/tinychat/utils/constants.py index 6920e88..350b66b 100644 --- a/tinychat/utils/constants.py +++ b/tinychat/utils/constants.py @@ -3,7 +3,7 @@ def init(): global max_seq_len, max_batch_size, llama_multiple_of, mem_efficient_load - max_seq_len = 2048 + max_seq_len = 5120 max_batch_size = 1 llama_multiple_of = 256 mem_efficient_load = False # Whether to load the checkpoint in a layer-wise manner. Activate this if you are facing OOM issues on edge devices (e.g., Jetson Orin).