Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batch inference becomes slower and slower #63

Open
fushi219 opened this issue Nov 4, 2024 · 2 comments
Open

batch inference becomes slower and slower #63

fushi219 opened this issue Nov 4, 2024 · 2 comments

Comments

@fushi219
Copy link

fushi219 commented Nov 4, 2024

Hi! Thanks for your job!
I do batch inference(~1000 images) on a V100 gpu.
At first, inference is quick (~1 s). However, for later images it becomes slower and slower (~30 s for the 300th image).
The memory-usage seems to be around 7.5G.
Are there any possible reasons for this problem?

@zhumorui
Copy link

I saw this too; this is a weird thing.

@zhumorui
Copy link

@fushi219
After I simplified the run.py in https://github.com/apple/ml-depth-pro/blob/main/src/depth_pro/cli/run.py, there is no speed loss any more.

#!/usr/bin/env python3
"""Simplified script to run DepthPro inference and save results.

Copyright (C) 2024 Apple Inc. All Rights Reserved.
"""

import argparse
import logging
from pathlib import Path

import numpy as np
import PIL.Image
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt


from depth_pro import create_model_and_transforms, load_rgb

LOGGER = logging.getLogger(__name__)

def get_torch_device() -> torch.device:
    """Pick the best available Torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        device_name = "cuda:0"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)

def run(args):
    """Run DepthPro on one image or a directory tree and save results.

    For each input image this writes, mirroring the input directory layout:
      * ``<output-path>/<rel>/<stem>.npz`` — compressed depth map (key ``depth``)
      * ``<output-path>/<rel>/<stem>.jpg`` — turbo color-mapped inverse depth
    Images whose outputs already exist are skipped, so an interrupted batch
    can be resumed.

    Args:
        args: Parsed CLI namespace with ``image_path`` (Path, file or dir),
            ``output_path`` (Path) and ``verbose`` (bool).
    """
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    # Load the model once, outside the image loop.
    model, transform = create_model_and_transforms(
        device=get_torch_device(),
        precision=torch.half,
    )
    model.eval()

    if args.image_path.is_dir():
        # BUG FIX: glob("**/*") also yields directories (and any stray
        # non-file entries); keep only regular files so load_rgb is never
        # handed a directory.
        image_paths = [p for p in args.image_path.glob("**/*") if p.is_file()]
        relative_path = args.image_path
    else:
        image_paths = [args.image_path]
        relative_path = args.image_path.parent

    for image_path in tqdm(image_paths):
        try:
            # Base output path: output dir / relative subdir / stem (no suffix).
            output_file = (
                args.output_path
                / image_path.relative_to(relative_path).parent
                / image_path.stem
            )
            # BUG FIX: with_suffix(".npz") mangles stems containing a dot
            # ("img.v2" -> "img.npz"), so the skip check never matched the
            # files actually written below ("img.v2.npz"). Append the suffix
            # to the full name instead, matching both np.savez_compressed
            # (which appends ".npz") and the jpg written below.
            npz_path = output_file.with_name(output_file.name + ".npz")
            jpg_path = output_file.with_name(output_file.name + ".jpg")
            if npz_path.exists() and jpg_path.exists():
                LOGGER.info(f"Outputs exist for {image_path}, skipping.")
                continue

            LOGGER.info(f"Processing image: {image_path}")
            image, _, f_px = load_rgb(image_path)

            # Run inference; move the result off the device immediately so
            # per-image tensors do not accumulate across the loop.
            prediction = model.infer(transform(image), f_px=f_px)
            depth = prediction["depth"].detach().cpu().numpy().squeeze()

            # Save depth as a compressed npz file (savez appends ".npz").
            LOGGER.info(f"Saving depth map to: {output_file}.npz")
            output_file.parent.mkdir(parents=True, exist_ok=True)
            np.savez_compressed(output_file, depth=depth)

            # Save a color-mapped "turbo" jpg: normalize inverse depth to the
            # [1/250, 1/0.1] visualization range (0.1 m near, 250 m far).
            inverse_depth = 1 / depth
            max_invdepth_vizu = min(inverse_depth.max(), 1 / 0.1)
            min_invdepth_vizu = max(1 / 250, inverse_depth.min())
            inverse_depth_normalized = (inverse_depth - min_invdepth_vizu) / (
                max_invdepth_vizu - min_invdepth_vizu
            )
            cmap = plt.get_cmap("turbo")
            color_depth = (cmap(inverse_depth_normalized)[..., :3] * 255).astype(
                np.uint8
            )
            LOGGER.info(f"Saving color-mapped depth to: {jpg_path}")
            PIL.Image.fromarray(color_depth).save(
                jpg_path, format="JPEG", quality=90
            )

        except Exception as e:
            # Best-effort batch processing: log the failure and continue with
            # the next image rather than aborting the whole run.
            LOGGER.error(f"Error processing {image_path}: {str(e)}")

    LOGGER.info("All images processed.")

def main():
    """Entry point: parse CLI arguments and hand them to ``run``."""
    parser = argparse.ArgumentParser(
        description="Simplified inference script for DepthPro."
    )
    # The two required path options share type/required settings.
    path_options = (
        ("-i", "--image-path", "Path to input image or directory."),
        ("-o", "--output-path", "Path to store output files."),
    )
    for short_flag, long_flag, help_text in path_options:
        parser.add_argument(
            short_flag, long_flag, type=Path, required=True, help=help_text
        )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Show verbose output."
    )
    run(parser.parse_args())

if __name__ == "__main__":
    main()

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants