diff --git a/pytorchvideo/data/labeled_video_dataset.py b/pytorchvideo/data/labeled_video_dataset.py
--- a/pytorchvideo/data/labeled_video_dataset.py
+++ b/pytorchvideo/data/labeled_video_dataset.py
@@ -4,8 +4,9 @@
 
 import gc
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
 
+import pandas as pd
 import torch.utils.data
 from pytorchvideo.data.clip_sampling import ClipSampler
 from pytorchvideo.data.video import VideoPathHandler
@@ -251,7 +252,7 @@ def __iter__(self):
 
 
 def labeled_video_dataset(
-    data_path: str,
+    data_path: Union[str, pd.DataFrame],
     clip_sampler: ClipSampler,
     video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
     transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
@@ -292,8 +293,12 @@ def labeled_video_dataset(
         decoder (str): Defines what type of decoder used to decode a video.
 
     """
-    labeled_video_paths = LabeledVideoPaths.from_path(data_path)
+    # Non-DataFrame inputs keep using from_path, exactly as before this change.
+    if isinstance(data_path, pd.DataFrame):
+        labeled_video_paths = LabeledVideoPaths.from_df(data_path)
+    else:
+        labeled_video_paths = LabeledVideoPaths.from_path(data_path)
     labeled_video_paths.path_prefix = video_path_prefix
     dataset = LabeledVideoDataset(
         labeled_video_paths,
         clip_sampler,
diff --git a/pytorchvideo/data/labeled_video_paths.py b/pytorchvideo/data/labeled_video_paths.py
--- a/pytorchvideo/data/labeled_video_paths.py
+++ b/pytorchvideo/data/labeled_video_paths.py
@@ -8,7 +8,8 @@
 
+import pandas as pd
 from iopath.common.file_io import g_pathmgr
 from torchvision.datasets.folder import make_dataset
 
 
 class LabeledVideoPaths:
     """
@@ -67,6 +68,35 @@ def from_csv(cls, file_path: str) -> LabeledVideoPaths:
         ), f"Failed to load dataset from {file_path}."
         return cls(video_paths_and_label)
 
+    @classmethod
+    def from_df(cls, df: pd.DataFrame) -> LabeledVideoPaths:
+        """
+        Factory function that creates a LabeledVideoPaths object from a dataframe.
+        The first column is read as the video path and all remaining columns are
+        converted to a float array that is stored as that video's label, e.g.:
+
+            df = pd.DataFrame(
+                {
+                    "path": ["path_to_video_1", "path_to_video_2"],
+                    "label": [0, 1],
+                }
+            )
+
+        Args:
+            df (pd.DataFrame): One row per video, path first and label(s) after.
+
+        Raises:
+            ValueError: If the dataframe has no rows.
+        """
+        # Column 0 is the path; the rest of each row becomes the label array.
+        video_paths_and_label = [
+            (row[0], row[1:].astype(float)) for row in df.to_numpy()
+        ]
+        if not video_paths_and_label:
+            # Raised explicitly because assert statements vanish under python -O.
+            raise ValueError("Failed to load dataset from df.")
+        return cls(video_paths_and_label)
+
     @classmethod
     def from_directory(cls, dir_path: str) -> LabeledVideoPaths:
         """