diff --git a/pytorchvideo/accelerator/deployment/common/model_transmuter.py b/pytorchvideo/accelerator/deployment/common/model_transmuter.py index e1593528..ebf74788 100644 --- a/pytorchvideo/accelerator/deployment/common/model_transmuter.py +++ b/pytorchvideo/accelerator/deployment/common/model_transmuter.py @@ -22,15 +22,27 @@ def _find_equivalent_efficient_module( module_name: str = "", ): """ + Searches for an equivalent efficientBlock that can replace the given `module_input` + within the efficient_block_transmuter_list. + Given module_input, search through efficient_block_registry to see whether the module_input can be replaced with equivalent efficientBlock. Returns None if no equivalent efficientBlock is found, else returns an instance of equivalent efficientBlock. + Args: - module_input (nn.Module): module to be replaced by equivalent efficientBlock - efficient_block_transmuter_list (list): a transmuter list that contains transmuter - functions for available efficientBlocks - module_name (str): name of module_input in original model + module_input (nn.Module): The module to be replaced by an equivalent efficientBlock. + efficient_block_transmuter_list (list): A list containing transmuter functions for + available efficientBlocks. + module_name (str): The name of `module_input` in the original model. + + Returns: + nn.Module or None: An instance of the equivalent efficientBlock if found; otherwise, None. + + This function iterates through the `efficient_block_transmuter_list` and applies each transmuter + function to `module_input`. If an equivalent efficientBlock is found, it is added to the + `eq_module_hit_list`. If multiple matches are found, a warning is logged, and the one with + the highest priority is chosen. If no matches are found, None is returned. """ eq_module_hit_list = [] for iter_func in efficient_block_transmuter_list: @@ -56,13 +68,20 @@ def transmute_model( prefix: str = "", ): """ - Recursively goes through user input model and replace module in place with available - equivalent efficientBlock for target device. + Recursively goes through the user input model and replaces modules in place with + equivalent efficientBlocks suitable for the target device. + Args: - model (nn.Module): user input model to be transmuted - target_device (str): name of target device, used to access transmuter list in - EFFICIENT_BLOCK_TRANSMUTER_REGISTRY - prefix (str): name of current hierarchy in user model + model (nn.Module): The user input model to be transmuted. + target_device (str): The name of the target device, used to access the transmuter + list in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY. + prefix (str): The name of the current hierarchy in the user model. + + This function recursively traverses the input `model`, examining each child module. + It attempts to find an equivalent efficientBlock for each module and replaces it + in the model if an equivalent is found. The replacement is logged for reference. + + Note: Make sure the target device is registered in the EFFICIENT_BLOCK_TRANSMUTER_REGISTRY. 
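+
+    Example (an illustrative usage sketch, not part of the original docs; it assumes
+    "mobile_cpu" is already registered in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY and
+    uses a toy model):
+
+        import torch.nn as nn
+        from pytorchvideo.accelerator.deployment.common.model_transmuter import (
+            transmute_model,
+        )
+
+        # Toy model with a 1x1x1 (pointwise) Conv3d that a registered transmuter
+        # may replace with an equivalent efficientBlock.
+        model = nn.Sequential(nn.Conv3d(3, 8, kernel_size=1), nn.ReLU())
+        transmute_model(model, target_device="mobile_cpu")  # replaces modules in place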
""" assert ( target_device in EFFICIENT_BLOCK_TRANSMUTER_REGISTRY diff --git a/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py b/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py index dfaee8a8..02c022f4 100644 --- a/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py +++ b/pytorchvideo/accelerator/deployment/mobile_cpu/transmuter/transmuter_mobile_cpu.py @@ -12,11 +12,21 @@ def transmute_Conv3dPwBnAct(input_module: nn.Module): """ - Given an input_module, transmutes it into a equivalent Conv3dPwBnAct. Returns None - if no equivalent Conv3dPwBnAct is found, else returns an instance of equivalent - Conv3dPwBnAct. + Transmutes the given `input_module` into an equivalent Conv3dPwBnAct module if applicable. + Args: - input_module (nn.Module): input module to find an equivalent Conv3dPwBnAct + input_module (nn.Module): The input module to find an equivalent Conv3dPwBnAct for. + + Returns: + Conv3dPwBnAct or None: An instance of the equivalent Conv3dPwBnAct module if found; + otherwise, None. + + This function checks if `input_module` is an instance of nn.Conv3d and if it matches specific + criteria, such as kernel size, groups, stride, padding, and dilation. If the criteria are met, + it creates and returns an equivalent Conv3dPwBnAct module, copying the weights if necessary. + + Note: Conv3dPwBnAct is a module that combines a 3D pointwise convolution with batch normalization + and activation functions. """ if not isinstance(input_module, nn.Conv3d): return None @@ -42,11 +52,22 @@ def transmute_Conv3dPwBnAct(input_module: nn.Module): def transmute_Conv3d3x3x3DwBnAct(input_module: nn.Module): """ - Given an input_module, transmutes it into a equivalent Conv3d3x3x3DwBnAct. Returns - None if no equivalent Conv3d3x3x3DwBnAct is found, else returns an instance of - equivalent Conv3d3x3x3DwBnAct. + Transmutes the given `input_module` into an equivalent Conv3d3x3x3DwBnAct module if applicable. + Args: - input_module (nn.Module): input module to find an equivalent Conv3d3x3x3DwBnAct + input_module (nn.Module): The input module to find an equivalent Conv3d3x3x3DwBnAct for. + + Returns: + Conv3d3x3x3DwBnAct or None: An instance of the equivalent Conv3d3x3x3DwBnAct module if found; + otherwise, None. + + This function checks if `input_module` is an instance of nn.Conv3d and if it matches specific + criteria, such as kernel size, in_channels, groups, stride, padding, padding_mode, and dilation. + If the criteria are met, it creates and returns an equivalent Conv3d3x3x3DwBnAct module, copying + the weights if necessary. + + Note: Conv3d3x3x3DwBnAct is a module that combines a 3D 3x3x3 depthwise convolution with batch + normalization and activation functions. """ if not isinstance(input_module, nn.Conv3d): return None diff --git a/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py b/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py index 2b36f54e..7512a9df 100644 --- a/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py +++ b/pytorchvideo/accelerator/deployment/mobile_cpu/utils/model_conversion.py @@ -17,14 +17,21 @@ def _add_input_tensor_size_lut_hook( base_name: str = "", ) -> None: """ - This helper function recursively goes through all modules in a network, registers - forward hook function to each module. The hook function records the input tensor - size in forward in input_tensor_size_lut[base_name]. 
+ Recursively adds a forward hook to each module in a network, recording input tensor sizes + in the provided input_tensor_size_lut dictionary. + Args: - module (nn.Module): input module to add hook recursively. - input_tensor_size_lut (dict): lut to record input tensor size for hook function. - hook_handle_list (list): a list to contain hook handles. - base_name (str): name for module input. + module (nn.Module): The input module to add hooks to, recursively. + input_tensor_size_lut (dict): A dictionary to record input tensor sizes for the hook function. + hook_handle_list (list): A list to contain hook handles. + base_name (str): The base name for the input module. + + This helper function iterates through the input `module` and its children, registering a forward + hook for each module. The hook function records the input tensor size for the module in the + `input_tensor_size_lut` dictionary using the `base_name` as the key. + + Note: Forward hooks are useful for monitoring and analyzing the input tensor sizes as they pass + through each module in a neural network. """ def hook_fn(_, _in, _out): @@ -51,21 +58,27 @@ def _convert_module( native_conv3d_op_qnnpack: bool = False, ) -> None: """ - This helper function recursively goes through sub-modules in a network. If current - module is a efficient block (instance of EfficientBlockBase) with convert() method, - its convert() method will be called, and the input tensor size (needed by efficient - blocks for mobile cpu) will be provided by matching module name in - input_tensor_size_lut. - Otherwise if the input module is a non efficient block, this function will try to go - through child modules of input module to look for any efficient block in lower - hierarchy. + Recursively traverses sub-modules in a neural network and performs module conversion + if applicable. For efficient blocks (instances of EfficientBlockBase) with a 'convert' + method, it calls the 'convert' method with the input tensor size obtained from the + 'input_tensor_size_lut'. If the module is not an efficient block, it explores child + modules to find efficient blocks in lower hierarchy. + Args: - module (nn.Module): input module for convert. - input_tensor_size_lut (dict): input tensor size look-up table. - base_name (str): module name for input module. - convert_for_quantize (bool): whether this module is intended to be quantized. - native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8 - Conv3d. + module (nn.Module): The input module for conversion. + input_tensor_size_lut (dict): A dictionary containing input tensor sizes for reference. + base_name (str): The name of the module. + convert_for_quantize (bool): Whether this module is intended for quantization. + native_conv3d_op_qnnpack (bool): Whether the QNNPACK version has native int8 Conv3d. + + This helper function is designed for recursively exploring a neural network and converting + specific modules, such as efficient blocks. If a module is an instance of EfficientBlockBase + and has a 'convert' method, it calls the 'convert' method with the input tensor size from + the 'input_tensor_size_lut'. If the module is not an efficient block, it continues to explore + its child modules in search of efficient blocks in lower hierarchies. + + Note: Module conversion is a common step in optimizing and adapting neural networks for + specific hardware or use cases, such as mobile CPUs. 
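+
+    Example (an illustrative sketch of how the helpers in this file are typically
+    chained; `model` and `example_input` are placeholder names, and
+    `convert_to_deployable_form` below additionally deep-copies the model first):
+
+        lut, handles = {}, []
+        _add_input_tensor_size_lut_hook(model, lut, handles)
+        model(example_input)         # forward pass fills the size LUT via the hooks
+        for handle in handles:
+            handle.remove()          # drop the temporary hooks
+        _convert_module(model, lut)  # convert efficient blocks using recorded sizes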
""" if isinstance(module, EfficientBlockBase): module.convert( @@ -91,19 +104,31 @@ def convert_to_deployable_form( native_conv3d_op_qnnpack: bool = False, ) -> nn.Module: """ - This function takes an input model, and returns a deployable model copy. + Converts an input model into a deployable form and returns a copy of the modified model. + Args: - model (nn.Module): input model for conversion. The model can include a mix of - efficient blocks (instances of EfficientBlockBase) and non efficient blocks. - The efficient blocks will be converted by calling its convert() method, while - other blocks will stay unchanged. - input_tensor (torch.Tensor): input tensor for model. Note current conversion for - deployable form in mobile cpu only works for single input tensor size (i.e., - the future input tensor to converted model should have the same size as - input_tensor specified here). - convert_for_quantize (bool): whether this module is intended to be quantized. - native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8 - Conv3d. + model (nn.Module): The input model for conversion. The model can consist of a mix + of efficient blocks (instances of EfficientBlockBase) and non-efficient blocks. + Efficient blocks are converted using their `convert()` method, while other + blocks remain unchanged. + input_tensor (torch.Tensor): The input tensor used for the model. The conversion for + deployable form on mobile CPU is designed for a single input tensor size. The + future input tensor to the converted model should match the size of the + `input_tensor` specified here. + convert_for_quantize (bool): Indicates whether this module is intended for quantization. + native_conv3d_op_qnnpack (bool): Specifies whether the QNNPACK version has native + int8 Conv3d support. + + Returns: + nn.Module: A copy of the input model converted into a deployable form. + + This function prepares the input model for deployment by performing the following steps: + 1. Captures input tensor sizes during forward pass. + 2. Executes a forward pass to record input tensor sizes. + 3. Removes forward hooks used for input tensor size capture. + 4. Creates a deep copy of the input model for conversion. + 5. Converts the copied model by applying the `_convert_module` function. + 6. Returns the converted model suitable for deployment. """ input_tensor_size_lut = {} hook_handle_list = [] diff --git a/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py b/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py index 1040218d..1590bde1 100644 --- a/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py +++ b/pytorchvideo/accelerator/efficient_blocks/efficient_block_base.py @@ -7,29 +7,37 @@ class EfficientBlockBase(nn.Module): """ - PyTorchVideo/accelerator provides a set of efficient blocks - that have optimal efficiency for each target hardware device. + The EfficientBlockBase is the foundation for efficient blocks provided by PyTorchVideo's accelerator. + These efficient blocks are designed for optimal efficiency on various target hardware devices. Each efficient block has two forms: - - original form: this form is for training. When efficient block is instantiated, - it is in this original form. - - deployable form: this form is for deployment. Once the network is ready for - deploy, it can be converted into deployable form for efficient execution - on target hardware. One block is transformed into deployable form by calling - convert() method. 
By conversion to deployable form, - various optimization (operator fuse, kernel optimization, etc.) are applied. - - EfficientBlockBase is the base class for efficient blocks. - All efficient blocks should inherit this base class - and implement following methods: - - forward(): same as required by nn.Module - - convert(): called to convert block into deployable form + - Original Form: This form is used during training. When an efficient block is instantiated, + it is in its original form. + - Deployable Form: This form is for deployment. Once the network is prepared for deployment, + it can be converted into the deployable form to enable efficient execution on the target hardware. + Conversion to the deployable form involves various optimizations such as operator fusion + and kernel optimization. + + All efficient blocks must inherit from this base class and implement the following methods: + - forward(): This method serves the same purpose as required by nn.Module. + - convert(): Called to transform the block into its deployable form. + + Subclasses of EfficientBlockBase should provide implementations for these methods to tailor + the behavior of the efficient block for specific use cases and target hardware. + + Note: This class is abstract, and its methods must be implemented in derived classes. """ @abstractmethod def convert(self): + """ + Abstract method to convert the efficient block into its deployable form. + """ pass @abstractmethod def forward(self): + """ + Abstract method for the forward pass of the efficient block. + """ pass diff --git a/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py b/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py index 81ce0aa5..b6a73d8d 100644 --- a/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py +++ b/pytorchvideo/accelerator/efficient_blocks/no_op_convert_block.py @@ -7,14 +7,20 @@ class NoOpConvertBlock(EfficientBlockBase): """ - This class provides an interface with EfficientBlockBase for modules that do not - need convert. + A class that provides an interface with EfficientBlockBase for modules that do not + require conversion. + Args: - model (nn.Module): NoOpConvertBlock takes model as input and generate a wrapper - instance of EfficientBlockBase with same functionality as model, with no change - applied when convert() is called. - """ + model (nn.Module): NoOpConvertBlock takes a model as input and generates a wrapper + instance of EfficientBlockBase with the same functionality as the model. When + `convert()` is called on this instance, no changes are applied. + This class is designed for modules that do not need any conversion when integrated into + an EfficientBlockBase. It takes an existing `model` and acts as a pass-through, forwarding + input directly to the underlying model during the `forward` pass. When `convert()` is + called, it simply does nothing, ensuring that no modifications are made to the model. + """ + def __init__(self, model: nn.Module): super().__init__() self.model = model diff --git a/pytorchvideo/data/ava.py b/pytorchvideo/data/ava.py index aed7c5e6..c9d569e4 100644 --- a/pytorchvideo/data/ava.py +++ b/pytorchvideo/data/ava.py @@ -14,10 +14,33 @@ class AvaLabeledVideoFramePaths: """ - Pre-processor for Ava Actions Dataset stored as image frames - - `_` - This class handles the parsing of all the necessary - csv files containing frame paths and frame labels. + Pre-processor for the Ava Actions Dataset stored as image frames. 
+ `_` + This class handles the parsing of all the necessary CSV files containing + frame paths and frame labels. + + Attributes: + AVA_VALID_FRAMES (list): Range of valid annotated frames in Ava dataset. + FPS (int): Frames per second in the dataset. + AVA_VIDEO_START_SEC (int): Start time of the video in seconds. + + Class Methods: + _aggregate_bboxes_labels(cls, inp: Dict): + Aggregates bounding boxes and labels. + + from_csv(cls, frame_paths_file: str, frame_labels_file: str, video_path_prefix: str, + label_map_file: Optional[str] = None) -> AvaLabeledVideoFramePaths: + Creates an instance of AvaLabeledVideoFramePaths from CSV files. + + load_and_parse_labels_csv(frame_labels_file: str, video_name_to_idx: dict, + allowed_class_ids: Optional[Set] = None): + Parses AVA per-frame labels from a CSV file. + + load_image_lists(frame_paths_file: str, video_path_prefix: str) -> Tuple: + Loads image paths from a file and constructs dictionaries for video indexing. + + read_label_map(label_map_file: str) -> Tuple: + Reads the label map and class IDs from a .pbtxt file. """ # Range of valid annotated frames in Ava dataset @@ -70,25 +93,50 @@ def from_csv( label_map_file: Optional[str] = None, ) -> AvaLabeledVideoFramePaths: """ + Creates an AvaLabeledVideoFramePaths object from CSV files containing frame paths and labels. + Args: - frame_labels_file (str): Path to the file containing containing labels - per key frame. Acceptible file formats are, - Type 1: - - Type 2: - - frame_paths_file (str): Path to a file containing relative paths - to all the frames in the video. Each line in the file is of the - form - video_path_prefix (str): Path to be augumented to the each relative frame - path to get the global frame path. - label_map_file (str): Path to a .pbtxt containing class id's and class names. - If not set, label_map is not loaded and bbox labels are not pruned - based on allowable class_id's in label_map. - Returs: - A list of tuples of the the form (video_frames directory, label dictionary). + frame_paths_file (str): + Path to a file containing relative paths to all the frames in the video. Each line in the file + is of the form . + + frame_labels_file (str): + Path to the file containing labels per key frame. Acceptable file formats are as follows: + Type 1 (CSV Columns): + - original_video_id + - frame_time_stamp + - bbox_x1 + - bbox_y1 + - bbox_x2 + - bbox_y2 + - action_label + - detection_iou + + Type 2 (CSV Columns): + - original_video_id + - frame_time_stamp + - bbox_x1 + - bbox_y1 + - bbox_x2 + - bbox_y2 + - action_label + - person_label + + video_path_prefix (str): + Path to be augmented to each relative frame path to get the global frame path. + + label_map_file (str): + Path to a .pbtxt file containing class IDs and class names. If not set, the label map is not + loaded, and bbox labels are not pruned based on allowable class_ids in the label map. + + Returns: + AvaLabeledVideoFramePaths: An AvaLabeledVideoFramePaths object. + + This class method initializes an AvaLabeledVideoFramePaths object from CSV files containing frame paths + and labels. It processes these files to create a list of labeled video paths, where each entry is a tuple + containing the path to the video frames directory and a label dictionary. + + Note: This function assumes specific CSV file formats and column names. 
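+
+        Example (an illustrative sketch; the file paths below are placeholders):
+
+            labeled_video_paths = AvaLabeledVideoFramePaths.from_csv(
+                frame_paths_file="frame_lists/train.csv",       # placeholder path
+                frame_labels_file="annotations/train.csv",      # placeholder path
+                video_path_prefix="/data/ava/frames",
+                label_map_file="annotations/action_list.pbtxt",
+            )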
""" if label_map_file is not None: _, allowed_class_ids = AvaLabeledVideoFramePaths.read_label_map( @@ -132,31 +180,57 @@ def load_and_parse_labels_csv( allowed_class_ids: Optional[Set] = None, ): """ - Parses AVA per frame labels .csv file. + Parses AVA per-frame labels from a CSV file. + Args: - frame_labels_file (str): Path to the file containing labels - per key frame. Acceptible file formats are, - Type 1: - - Type 2: - - video_name_to_idx (dict): Dictionary mapping video names to indices. - allowed_class_ids (set): A set of integer unique class (bbox label) - id's that are allowed in the dataset. If not set, all class id's - are allowed in the bbox labels. + frame_labels_file (str): + Path to the file containing labels per key frame. Acceptable file formats are as follows: + Type 1 (CSV Columns): + - original_video_id + - frame_time_stamp + - bbox_x1 + - bbox_y1 + - bbox_x2 + - bbox_y2 + - action_label + - detection_iou + + Type 2 (CSV Columns): + - original_video_id + - frame_time_stamp + - bbox_x1 + - bbox_y1 + - bbox_x2 + - bbox_y2 + - action_label + - person_label + + video_name_to_idx (dict): + A dictionary mapping video names to indices. + + allowed_class_ids (set): + A set of unique integer class (bbox label) IDs that are allowed in the dataset. + If not set, all class IDs are allowed in the bbox labels. + Returns: - (dict): A dictionary of dictionary containing labels per each keyframe - in each video. Here, the label for each keyframe is again a dict - of the form, - { - 'labels': a list of bounding boxes - 'boxes':a list of action lables for the bounding box - 'extra_info': ist of extra information cotaining either - detections iou's or person id's depending on the - csv format. - } + dict: A dictionary containing labels for each keyframe in each video. The structure is as follows: + { + video_idx (int): { + frame_sec (float): { + 'labels': List of bounding box labels, + 'boxes': List of bounding boxes, + 'extra_info': List of extra information containing either detections' IoU or person IDs + }, + ... + }, + ... + } + + This function parses a CSV file containing per-frame labels, extracts the necessary information, + and organizes it into a nested dictionary structure. The structure allows easy access to labels, + bounding boxes, and extra information for each keyframe in each video. + + Note: This function assumes specific CSV file formats and column names. """ labels_dict = {} with g_pathmgr.open(frame_labels_file, "r") as f: @@ -202,20 +276,31 @@ def load_and_parse_labels_csv( @staticmethod def load_image_lists(frame_paths_file: str, video_path_prefix: str) -> Tuple: """ - Loading image paths from the corresponding file. + Loads image paths from the corresponding file. + Args: - frame_paths_file (str): Path to a file containing relative paths - to all the frames in the video. Each line in the file is of the - form - video_path_prefix (str): Path to be augumented to the each relative - frame path to get the global frame path. + frame_paths_file (str): + Path to a file containing relative paths to all the frames in the video. + Each line in the file is of the form . + + video_path_prefix (str): + Path to be augmented to each relative frame path to get the global frame path. + Returns: - (tuple): A tuple of the following, - image_paths_list: List of list containing absolute frame paths. - Wherein the outer list is per video and inner list is per - timestamp. 
- video_idx_to_name: A dictionary mapping video index to name - video_name_to_idx: A dictionary maoping video name to index + Tuple: + A tuple containing the following elements: + - image_paths_list (List[List[str]]): A list of lists containing absolute frame paths. + The outer list is per video, and the inner list is per timestamp. + - video_idx_to_name (Dict[int, str]): A dictionary mapping video index to video name. + - video_name_to_idx (Dict[str, int]): A dictionary mapping video name to video index. + + This function parses a file containing frame paths and their associated video information. + It organizes the frame paths into a list of lists, where each outer list represents a video, + and each inner list represents timestamps within that video. The video information is also + indexed and mapped for reference. + + The file format should follow: + original_video_id video_id frame_id path labels """ image_paths = [] @@ -255,15 +340,18 @@ def load_image_lists(frame_paths_file: str, video_path_prefix: str) -> Tuple: @staticmethod def read_label_map(label_map_file: str) -> Tuple: """ - Read label map and class ids. + Read a label map and extract class IDs and their associated class names. + Args: - label_map_file (str): Path to a .pbtxt containing class id's - and class names + label_map_file (str): The path to a .pbtxt file containing class IDs and class names. + Returns: - (tuple): A tuple of the following, - label_map (dict): A dictionary mapping class id to - the associated class names. - class_ids (set): A set of integer unique class id's + tuple: A tuple containing the following elements: + - label_map (Dict[int, str]): A dictionary mapping class IDs (integers) to their associated class names (strings). + - class_ids (Set[int]): A set containing unique class IDs (integers). + + This static method reads the contents of a .pbtxt file and extracts the class IDs and their associated class names. + It returns a tuple containing a dictionary that maps class IDs to class names and a set of unique class IDs. """ label_map = {} class_ids = set() @@ -282,16 +370,16 @@ def read_label_map(label_map_file: str) -> Tuple: class TimeStampClipSampler: """ - A sepcialized clip sampler for sampling video clips around specific - timestamps. This is particularly used in datasets like Ava wherein only - a specific subset of clips in the video have annotations + A specialized clip sampler for sampling video clips around specific timestamps. This is particularly used + in datasets like Ava where only a specific subset of clips in the video have annotations. """ def __init__(self, clip_sampler: ClipSampler) -> None: """ + Initializes a TimeStampClipSampler. + Args: - clip_sampler (`pytorchvideo.data.ClipSampler`): Strategy used for sampling - between the untrimmed clip boundary. + clip_sampler (ClipSampler): The strategy used for sampling between the untrimmed clip boundary. """ self.clip_sampler = clip_sampler @@ -299,14 +387,22 @@ def __call__( self, last_clip_time: float, video_duration: float, annotation: Dict[str, Any] ) -> ClipInfo: """ + Samples a video clip around a specific timestamp. + Args: last_clip_time (float): Not used for TimeStampClipSampler. - video_duration: (float): Not used for TimeStampClipSampler. - annotation (Dict): Dict containing time step to sample aroud. + video_duration (float): Not used for TimeStampClipSampler. + annotation (Dict): A dictionary containing the time step to sample around. 
+ Returns: - clip_info (ClipInfo): includes the clip information of (clip_start_time, - clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds. - clip_index, aux_index and is_last_clip are always 0, 0 and True, respectively. + ClipInfo: An object including clip information with the following fields: + - clip_start_time (float): The start time of the sampled clip in seconds. + - clip_end_time (float): The end time of the sampled clip in seconds. + - clip_index (int): Always 0. + - aug_index (int): Always 0. + - is_last_clip (bool): Always True. + + The `center_frame_sec` in the annotation dictionary represents the timestamp around which the clip is sampled. """ center_frame_sec = annotation["clip_index"] # a.k.a timestamp clip_start_sec = center_frame_sec - self.clip_sampler._clip_duration / 2.0 @@ -319,9 +415,13 @@ def __call__( ) def reset(self) -> None: + """ + Resets the TimeStampClipSampler. + """ pass + def Ava( frame_paths_file: str, frame_labels_file: str, @@ -332,33 +432,39 @@ def Ava( transform: Optional[Callable[[dict], Any]] = None, ) -> None: """ + Creates a dataset for the AVA dataset with labeled video frames. + Args: frame_paths_file (str): Path to a file containing relative paths to all the frames in the video. Each line in the file is of the - form - frame_labels_file (str): Path to the file containing containing labels - per key frame. Acceptible file formats are, + form . + frame_labels_file (str): Path to the file containing labels + per key frame. Acceptable file formats are: Type 1: - + Type 2: - - video_path_prefix (str): Path to be augumented to the each relative frame - path to get the global frame path. - label_map_file (str): Path to a .pbtxt containing class id's - and class names. If not set, label_map is not loaded and bbox labels are - not pruned based on allowable class_id's in label_map. - clip_sampler (ClipSampler): Defines how clips should be sampled from each - video. + . + video_path_prefix (str): Path to be augmented to each relative frame + path to obtain the global frame path. + label_map_file (str): Path to a .pbtxt containing class IDs + and class names. If not set, the label_map is not loaded, and bbox labels are + not pruned based on allowable class_ids in label_map. + clip_sampler (ClipSampler): Defines how clips should be sampled from each video. video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal - video container. This defines the order videos are decoded and, - if necessary, the distributed split. + video container. This defines the order in which videos are decoded and, + if necessary, the distributed split. transform (Optional[Callable]): This callable is evaluated on the clip output and the corresponding bounding boxes before the clip and the bounding boxes - are returned. It can be used for user defined preprocessing and + are returned. It can be used for user-defined preprocessing and augmentations to the clips. If transform is None, the clip and bounding - boxes are returned as it is. + boxes are returned as they are. + + Returns: + LabeledVideoDataset: A dataset containing labeled video frames for the AVA dataset. + + This function reads frame paths and labels from specified files, constructs a dataset, and returns it. 
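+
+    Example (an illustrative sketch; the paths are placeholders and a 2-second random
+    clip sampler is assumed):
+
+        import torch
+        from pytorchvideo.data.ava import Ava
+        from pytorchvideo.data.clip_sampling import make_clip_sampler
+
+        dataset = Ava(
+            frame_paths_file="frame_lists/train.csv",        # placeholder path
+            frame_labels_file="annotations/train.csv",       # placeholder path
+            video_path_prefix="/data/ava/frames",
+            label_map_file="annotations/action_list.pbtxt",
+            clip_sampler=make_clip_sampler("random", 2.0),
+            video_sampler=torch.utils.data.RandomSampler,
+        )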
""" labeled_video_paths = AvaLabeledVideoFramePaths.from_csv( frame_paths_file, diff --git a/pytorchvideo/data/charades.py b/pytorchvideo/data/charades.py index c211a613..e0ab1a63 100644 --- a/pytorchvideo/data/charades.py +++ b/pytorchvideo/data/charades.py @@ -18,12 +18,14 @@ class Charades(torch.utils.data.IterableDataset): """ - Action recognition video dataset for + Action recognition video dataset for Charades stored as image frames. `Charades `_ stored as image frames. - This dataset handles the parsing of frames, loading and clip sampling for the - videos. All io is done through :code:`iopath.common.file_io.PathManager`, enabling - non-local storage uri's to be used. + This dataset handles the parsing of frames, loading, and clip sampling for the videos. + All I/O is done through `iopath.common.file_io.PathManager`, enabling non-local storage URIs to be used. + + Attributes: + NUM_CLASSES (int): Number of classes represented by this dataset's annotated labels. """ # Number of classes represented by this dataset's annotated labels. @@ -39,23 +41,22 @@ def __init__( frames_per_clip: Optional[int] = None, ) -> None: """ + Initializes a Charades dataset. + Args: - data_path (str): Path to the data file. This file must be a space - separated csv with the format: (original_vido_id video_id frame_id - path_labels) + data_path (str): Path to the data file. This file must be a space-separated CSV with the format: + (original_video_id video_id frame_id path_labels) - clip_sampler (ClipSampler): Defines how clips should be sampled from each - video. See the clip sampling documentation for more information. + clip_sampler (ClipSampler): Defines how clips should be sampled from each video. - video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal - video container. This defines the order videos are decoded and, - if necessary, the distributed split. + video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal video container. + This defines the order videos are decoded and, if necessary, the distributed split. transform (Optional[Callable]): This callable is evaluated on the clip output before - the clip is returned. It can be used for user defined preprocessing and - augmentations on the clips. The clip output format is described in __next__(). + the clip is returned. It can be used for user-defined preprocessing and augmentations + on the clips. - video_path_prefix (str): prefix path to add to all paths from data_path. + video_path_prefix (str): Prefix path to add to all paths from data_path. frames_per_clip (Optional[int]): The number of frames per clip to sample. """ @@ -92,12 +93,14 @@ def _sample_clip_frames( frame_indices: List[int], frames_per_clip: int ) -> List[int]: """ + Subsamples a list of frame indices to obtain a clip with a specified number of frames. + Args: - frame_indices (list): list of frame indices. - frames_per+clip (int): The number of frames per clip to sample. + frame_indices (list): List of frame indices. + frames_per_clip (int): The number of frames per clip to sample. Returns: - (list): Outputs a subsampled list with num_samples frames. + list: Subsampled list of frame indices for the clip. """ num_frames = len(frame_indices) indices = torch.linspace(0, num_frames - 1, frames_per_clip) @@ -107,6 +110,9 @@ def _sample_clip_frames( @property def video_sampler(self) -> torch.utils.data.Sampler: + """ + Returns the video sampler used by this dataset. 
+ """ return self._video_sampler def __next__(self) -> dict: @@ -184,14 +190,32 @@ def _read_video_paths_and_labels( video_path_label_file: List[str], prefix: str = "" ) -> Tuple[List[str], List[int]]: """ - Args: - video_path_label_file (List[str]): a file that contains frame paths for each - video and the corresponding frame label. The file must be a space separated - csv of the format: - `original_vido_id video_id frame_id path labels` - - prefix (str): prefix path to add to all paths from video_path_label_file. + Reads video frame paths and associated labels from a CSV file. + Args: + video_path_label_file (List[str]): A file containing frame paths for each video + and their corresponding frame labels. The file must be a space-separated CSV + with the following format: + original_vido_id video_id frame_id path labels + + prefix (str): A prefix path to add to all frame paths from the CSV file. + + Returns: + Tuple[List[str], List[int]]: A tuple containing lists of video frame paths and + their associated labels. + + Example: + Given a CSV file with the following format: + ``` + original_vido_id video_id frame_id path labels + video1 1 frame1.jpg /path/to/frames/1.jpg "1,2,3" + video1 1 frame2.jpg /path/to/frames/2.jpg "2,3" + video2 2 frame1.jpg /path/to/frames/1.jpg "4,5" + ``` + + The function would return the following tuple: + (['/path/to/frames/1.jpg', '/path/to/frames/2.jpg', '/path/to/frames/1.jpg'], + [[1, 2, 3], [2, 3], [4, 5]]) """ image_paths = defaultdict(list) labels = defaultdict(list) diff --git a/pytorchvideo/data/clip_sampling.py b/pytorchvideo/data/clip_sampling.py index f59c5c1e..d40d501d 100644 --- a/pytorchvideo/data/clip_sampling.py +++ b/pytorchvideo/data/clip_sampling.py @@ -8,13 +8,15 @@ class ClipInfo(NamedTuple): """ - Named-tuple for clip information with: - clip_start_sec (Union[float, Fraction]): clip start time. - clip_end_sec (Union[float, Fraction]): clip end time. - clip_index (int): clip index in the video. - aug_index (int): augmentation index for the clip. Different augmentation methods + Contains information about a video clip. + + Attributes: + clip_start_sec (Union[float, Fraction]): The start time of the clip. + clip_end_sec (Union[float, Fraction]): The end time of the clip. + clip_index (int): The index of the clip in the video. + aug_index (int): The augmentation index for the clip. Different augmentation methods might generate multiple views for the same clip. - is_last_clip (bool): a bool specifying whether there are more clips to be + is_last_clip (bool): A boolean specifying whether there are more clips to be sampled from the video. """ @@ -27,13 +29,15 @@ class ClipInfo(NamedTuple): class ClipInfoList(NamedTuple): """ - Named-tuple for clip information with: - clip_start_sec (float): clip start time. - clip_end_sec (float): clip end time. - clip_index (int): clip index in the video. - aug_index (int): augmentation index for the clip. Different augmentation methods + Contains lists of clip information. + + Attributes: + clip_start_sec (List[float]): List of clip start times. + clip_end_sec (List[float]): List of clip end times. + clip_index (List[int]): List of clip indices in the video. + aug_index (List[int]): List of augmentation indices for the clips. Different augmentation methods might generate multiple views for the same clip. - is_last_clip (bool): a bool specifying whether there are more clips to be + is_last_clip (List[bool]): List of booleans specifying whether there are more clips to be sampled from the video. 
""" @@ -47,10 +51,17 @@ class ClipInfoList(NamedTuple): class ClipSampler(ABC): """ Interface for clip samplers that take a video time, previous sampled clip time, - and returns a named-tuple ``ClipInfo``. + and returns a `ClipInfo` namedtuple. """ def __init__(self, clip_duration: Union[float, Fraction]) -> None: + """ + Initializes a `ClipSampler` with a specified clip duration. + + Args: + clip_duration (Union[float, Fraction]): The duration of each sampled clip. + """ + self._clip_duration = Fraction(clip_duration) self._current_clip_index = 0 self._current_aug_index = 0 @@ -62,26 +73,43 @@ def __call__( video_duration: Union[float, Fraction], annotation: Dict[str, Any], ) -> ClipInfo: + """ + Samples the next video clip and returns its information. + + Args: + last_clip_end_time (Union[float, Fraction]): The end time of the last sampled clip. + video_duration (Union[float, Fraction]): The total duration of the video. + annotation (Dict[str, Any]): Additional annotation or information. + + Returns: + ClipInfo: A namedtuple containing information about the sampled clip. + """ pass def reset(self) -> None: - """Resets any video-specific attributes in preperation for next video""" + """Resets any video-specific attributes in preparation for the next video.""" pass def make_clip_sampler(sampling_type: str, *args) -> ClipSampler: """ - Constructs the clip samplers found in ``pytorchvideo.data.clip_sampling`` from the - given arguments. + Constructs a clip sampler based on the specified sampling type and arguments. Args: - sampling_type (str): choose clip sampler to return. It has three options: + sampling_type (str): The type of clip sampler to create. Supported options are: + + - "uniform": Constructs and returns a UniformClipSampler. + - "random": Constructs and returns a RandomClipSampler. + - "constant_clips_per_video": Constructs and returns a ConstantClipsPerVideoSampler. + - "random_multi": Constructs and returns a RandomMultiClipSampler. + + *args: Additional arguments to pass to the chosen clip sampler constructor. - * uniform: constructs and return ``UniformClipSampler`` - * random: construct and return ``RandomClipSampler`` - * constant_clips_per_video: construct and return ``ConstantClipsPerVideoSampler`` + Returns: + ClipSampler: An instance of the selected clip sampler based on the specified type. - *args: the args to pass to the chosen clip sampler constructor. + Raises: + NotImplementedError: If the specified sampling_type is not supported. """ if sampling_type == "uniform": return UniformClipSampler(*args) @@ -108,29 +136,31 @@ def __init__( eps: float = 1e-6, ): """ + Initializes a UniformClipSampler. + Args: clip_duration (Union[float, Fraction]): The length of the clip to sample (in seconds). stride (Union[float, Fraction], optional): - The amount of seconds to offset the next clip by - default value of None is equivalent to no stride => stride == clip_duration. + The amount of seconds to offset the next clip by. + If None, it defaults to `clip_duration`, meaning no overlap between clips. eps (float): Epsilon for floating point comparisons. Used to check the last clip. backpad_last (bool): Whether to include the last frame(s) by "back padding". For instance, if we have a video of 39 frames (30 fps = 1.3s) - with a stride of 16 (0.533s) with a clip duration of 32 frames - (1.0667s). The clips will be (in frame numbers): - - with backpad_last = False + with a stride of 16 (0.533s) and a clip duration of 32 frames (1.0667s). 
+ Clips without backpad_last: - [0, 31] - with backpad_last = True + Clips with backpad_last: - [0, 31] - [8, 39], this is "back-padded" from [16, 48] to fit the last window - Note that you can use Fraction for clip_duration and stride if you want to - avoid float precision issue and need accurate frames in each clip. + + Note: + You can use Fraction for `clip_duration` and `stride` to avoid float precision + issues and obtain accurate frame counts in each clip. """ super().__init__(clip_duration) self._stride = stride if stride is not None else self._clip_duration @@ -146,7 +176,21 @@ def _clip_start_end( backpad_last: bool, ) -> Tuple[Fraction, Fraction]: """ - Helper to calculate the start/end clip with backpad logic + Calculates the start and end time of the next clip with optional back padding logic. + + Args: + last_clip_end_time (Union[float, Fraction]): + The end time of the previous clip sampled from the video. + Should be 0.0 if the video hasn't had clips sampled yet. + video_duration (Union[float, Fraction]): + The duration of the video being sampled in seconds. + backpad_last (bool): + Whether to include the last frame(s) by "back padding". + + Returns: + Tuple[Fraction, Fraction]: A tuple containing the start and end times of the clip + in seconds (Fractions if used). The clip's end time may be adjusted if back padding + is enabled to ensure it doesn't exceed the video duration. """ delta = self._stride - self._clip_duration last_end_time = -delta if last_clip_end_time is None else last_clip_end_time @@ -167,15 +211,18 @@ def __call__( annotation: Dict[str, Any], ) -> ClipInfo: """ + Samples the next clip from the video. + Args: - last_clip_end_time (float): the last clip end time sampled from this video. This - should be 0.0 if the video hasn't had clips sampled yet. - video_duration: (float): the duration of the video that's being sampled in seconds + last_clip_end_time (float): The last clip end time sampled from this video. + Should be 0.0 if the video hasn't had clips sampled yet. + video_duration: (float): The duration of the video being sampled in seconds. annotation (Dict): Not used by this sampler. + Returns: - clip_info: (ClipInfo): includes the clip information (clip_start_time, + clip_info: (ClipInfo): Includes the clip information (clip_start_time, clip_end_time, clip_index, aug_index, is_last_clip), where the times are in - seconds and is_last_clip is False when there is still more of time in the video + seconds, and is_last_clip is False when there is still more time in the video to be sampled. """ clip_start, clip_end = self._clip_start_end( @@ -218,6 +265,27 @@ def __init__( eps: float = 1e-6, truncation_duration: float = None, ) -> None: + """ + Initializes a UniformClipSamplerTruncateFromStart. + + Args: + clip_duration (Union[float, Fraction]): + The length of the clip to sample (in seconds). + stride (Union[float, Fraction], optional): + The amount of seconds to offset the next clip by. + If None, it defaults to `clip_duration`, meaning no overlap between clips. + eps (float): + Epsilon for floating point comparisons. Used to check the last clip. + backpad_last (bool): + Whether to include the last frame(s) by "back padding". + truncation_duration (float, optional): + The maximum duration to truncate the video to. Clips will be sampled from + [0, truncation_duration] if set. + + Note: + You can use Fraction for `clip_duration` and `stride` to avoid float precision + issues and obtain accurate frame counts in each clip. 
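+
+        Example (an illustrative sketch; arguments not shown are assumed to keep
+        their documented defaults):
+
+            from pytorchvideo.data.clip_sampling import make_clip_sampler
+
+            # Non-overlapping 2-second clips over the whole video.
+            uniform = make_clip_sampler("uniform", 2.0)
+
+            # The same uniform sampling, but clips are drawn only from the
+            # first 10 seconds of each video.
+            truncated = UniformClipSamplerTruncateFromStart(
+                clip_duration=2.0, truncation_duration=10.0
+            )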
+ """ super().__init__(clip_duration, stride, backpad_last, eps) self.truncation_duration = truncation_duration @@ -227,6 +295,20 @@ def __call__( video_duration: float, annotation: Dict[str, Any], ) -> ClipInfo: + """ + Samples the next clip from the video. + + Args: + last_clip_end_time (float): The last clip end time sampled from this video. + video_duration: (float): The duration of the video being sampled in seconds. + annotation (Dict): Not used by this sampler. + + Returns: + clip_info: (ClipInfo): Includes the clip information (clip_start_time, + clip_end_time, clip_index, aug_index, is_last_clip), where the times are in + seconds, and is_last_clip is False when there is still more time in the video + to be sampled. + """ truncated_video_duration = video_duration if self.truncation_duration is not None: @@ -239,7 +321,7 @@ def __call__( class RandomClipSampler(ClipSampler): """ - Randomly samples clip of size clip_duration from the videos. + Randomly samples clips of size clip_duration from the videos. """ def __call__( @@ -249,16 +331,17 @@ def __call__( annotation: Dict[str, Any], ) -> ClipInfo: """ + Samples a random clip of the specified duration from the video. + Args: last_clip_end_time (float): Not used for RandomClipSampler. - video_duration: (float): the duration (in seconds) for the video that's - being sampled + video_duration (float): The duration (in seconds) of the video being sampled. annotation (Dict): Not used by this sampler. + Returns: - clip_info (ClipInfo): includes the clip information of (clip_start_time, + clip_info (ClipInfo): Contains clip information (clip_start_time, clip_end_time, clip_index, aug_index, is_last_clip). The times are in seconds. - clip_index, aux_index and is_last_clip are always 0, 0 and True, respectively. - + clip_index, aug_index, and is_last_clip are always 0, 0, and True, respectively. """ max_possible_clip_start = max(video_duration - self._clip_duration, 0) clip_start_sec = Fraction(random.uniform(0, max_possible_clip_start)) @@ -269,7 +352,22 @@ def __call__( class RandomMultiClipSampler(RandomClipSampler): """ - Randomly samples multiple clips of size clip_duration from the videos. + Randomly samples multiple clips of a specified duration from videos. + + This class extends RandomClipSampler to sample multiple clips from videos. It randomly selects + 'num_clips' clips of size 'clip_duration' from the given video, ensuring randomness and diversity. + + Args: + clip_duration (float): The duration of each sampled clip in seconds. + num_clips (int): The number of clips to sample from the video. + + Attributes: + _num_clips (int): The number of clips to be sampled. + + Methods: + __call__(self, last_clip_end_time, video_duration, annotation): + Randomly selects 'num_clips' clips from the video, using the underlying RandomClipSampler. + Returns information about the sampled clips in a ClipInfoList. """ def __init__(self, clip_duration: float, num_clips: int) -> None: @@ -282,7 +380,19 @@ def __call__( video_duration: float, annotation: Dict[str, Any], ) -> ClipInfoList: + """ + Randomly selects 'num_clips' clips of duration 'clip_duration' from the video. + + Args: + last_clip_end_time (float or None): The end time of the last sampled clip, or None if no previous clip. + video_duration (float): The total duration of the video in seconds. + annotation (dict): Additional annotation data associated with the video. 
+ Returns: + ClipInfoList: A list of ClipInfo objects containing information about the sampled clips, + including start and end times, clip indices, augmentation indices, and flags indicating + if a clip is the last in the sequence. + """ ( clip_start_list, clip_end_list, @@ -316,9 +426,25 @@ def __call__( class RandomMultiClipSamplerTruncateFromStart(RandomMultiClipSampler): """ - Randomly samples multiple clips of size clip_duration from the videos. - If truncation_duration is set, clips sampled from [0, truncation_duration]. - If truncation_duration is not set, defaults to RandomMultiClipSampler. + Randomly samples multiple clips of a specified duration from videos with optional truncation. + + This class extends RandomMultiClipSampler to sample multiple clips from videos. It provides + an option to truncate the video to a specified duration before sampling clips from the beginning. + + Args: + clip_duration (float): The duration of each sampled clip in seconds. + num_clips (int): The number of clips to sample from the video. + truncation_duration (float or None, optional): The duration to truncate the video to + before sampling clips. If None, the entire video is used. + + Attributes: + truncation_duration (float or None): The duration to truncate the video to, if specified. + + Methods: + __call__(self, last_clip_end_time, video_duration, annotation): + Randomly selects 'num_clips' clips of duration 'clip_duration' from the video, with + optional truncation. + Returns information about the sampled clips in a ClipInfoList. """ def __init__( @@ -333,7 +459,19 @@ def __call__( video_duration: float, annotation: Dict[str, Any], ) -> ClipInfoList: + """ + Randomly selects 'num_clips' clips of duration 'clip_duration' from the video, with optional truncation. + + Args: + last_clip_end_time (float or None): The end time of the last sampled clip, or None if no previous clip. + video_duration (float): The total duration of the video in seconds. + annotation (dict): Additional annotation data associated with the video. + Returns: + ClipInfoList: A list of ClipInfo objects containing information about the sampled clips, + including start and end times, clip indices, augmentation indices, and flags indicating + if a clip is the last in the sequence. + """ truncated_video_duration = video_duration if self.truncation_duration is not None: truncated_video_duration = min(self.truncation_duration, video_duration) @@ -345,8 +483,25 @@ def __call__( class ConstantClipsPerVideoSampler(ClipSampler): """ - Evenly splits the video into clips_per_video increments and samples clips of size - clip_duration at these increments. + Evenly splits a video into a fixed number of clips and samples clips of a specified duration. + + This class evenly divides a video into 'clips_per_video' increments and samples clips of + size 'clip_duration' at these increments. It allows for multiple augmentations per clip. + + Args: + clip_duration (float): The duration of each sampled clip in seconds. + clips_per_video (int): The number of clips to evenly sample from the video. + augs_per_clip (int, optional): The number of augmentations to apply to each sampled clip. Default is 1. + + Attributes: + _clips_per_video (int): The fixed number of clips to sample from the video. + _augs_per_clip (int): The number of augmentations to apply per clip. + + Methods: + __call__(self, last_clip_end_time, video_duration, annotation): + Samples the next clip and returns clip information. 
+ reset(self): + Resets the internal state for sampling clips. """ def __init__( @@ -363,6 +518,8 @@ def __call__( annotation: Dict[str, Any], ) -> ClipInfo: """ + Samples the next clip from the video. + Args: last_clip_end_time (float): Not used for ConstantClipsPerVideoSampler. video_duration: (float): the duration (in seconds) for the video that's @@ -409,5 +566,11 @@ def __call__( ) def reset(self): + """ + Resets the internal state for sampling clips. + + This method resets the internal indices used for sampling clips, allowing you to start + sampling from the beginning of the video again. + """ self._current_clip_index = 0 self._current_aug_index = 0 diff --git a/pytorchvideo/data/dataset_manifest_utils.py b/pytorchvideo/data/dataset_manifest_utils.py index 948dbde6..74dcd2e8 100644 --- a/pytorchvideo/data/dataset_manifest_utils.py +++ b/pytorchvideo/data/dataset_manifest_utils.py @@ -84,6 +84,30 @@ def _load_images( video_info_file_path: str, multithreaded_io: bool, ) -> Dict[str, ImageFrameInfo]: + """ + Load image frame information from data files and create a dictionary of ImageFrameInfo objects. + + This static method reads information about image frames from data files specified by + 'frame_manifest_file_path' and 'video_info_file_path' and organizes it into a dictionary + of ImageFrameInfo objects. It ensures consistency and completeness of data between video + information and frame information. + + Args: + frame_manifest_file_path (str or None): The file path to the manifest containing frame information. + If None, frame information will not be loaded. + video_info_file_path (str): The file path to the CSV file containing video information. + multithreaded_io (bool): A flag indicating whether to use multithreaded I/O operations. + + Returns: + Dict[str, ImageFrameInfo]: A dictionary where the keys are frame IDs, and the values + are ImageFrameInfo objects containing information about each image frame. + + Note: + - If 'frame_manifest_file_path' is None, frame information will not be loaded. + - The 'frame_manifest_file_path' and 'video_info_file_path' CSV files must have a common + key for matching video and frame data. + + """ video_infos: Dict[str, VideoInfo] = load_dataclass_dict_from_csv( video_info_file_path, VideoInfo, "video_id" ) @@ -122,6 +146,29 @@ def _load_videos( multithreaded_io: bool, dataset_type: VideoDatasetType, ) -> Dict[str, Video]: + """ + Load videos or frame data and create a dictionary of Video objects. + + This static method loads video data or frame information from specified data files and organizes + it into a dictionary of Video objects. The type of dataset loaded depends on the 'dataset_type' + parameter. + + Args: + video_data_manifest_file_path (str or None): The file path to the manifest containing video or + frame data. If None, video data or frame data will not be loaded. + video_info_file_path (str): The file path to the CSV file containing video information. + multithreaded_io (bool): A flag indicating whether to use multithreaded I/O operations. + dataset_type (VideoDatasetType): The type of dataset to load, either Frame or EncodedVideo. + + Returns: + Dict[str, Video]: A dictionary where the keys are video IDs, and the values are Video objects. + + Note: + - If 'video_data_manifest_file_path' is None, video data or frame data will not be loaded. + - The 'video_data_manifest_file_path' and 'video_info_file_path' CSV files must have a common + key for matching video and frame data. 
+ """ + video_infos: Dict[str, VideoInfo] = load_dataclass_dict_from_csv( video_info_file_path, VideoInfo, "video_id" ) @@ -140,6 +187,22 @@ def _load_frame_videos( video_infos: Dict[str, VideoInfo], multithreaded_io: bool, ): + """ + Load frame videos and create a dictionary of FrameVideo objects. + + This static method loads frame video data from the specified frame manifest file and organizes it + into a dictionary of FrameVideo objects. It ensures consistency and completeness of data between + video information and frame information. + + Args: + frame_manifest_file_path (str): The file path to the manifest containing frame information. + video_infos (Dict[str, VideoInfo]): A dictionary of video information keyed by video ID. + multithreaded_io (bool): A flag indicating whether to use multithreaded I/O operations. + + Returns: + Dict[str, FrameVideo]: A dictionary where the keys are video IDs, and the values are FrameVideo + objects containing frame video data. + """ video_frames: Dict[str, VideoFrameInfo] = load_dataclass_dict_from_csv( frame_manifest_file_path, VideoFrameInfo, "video_id" ) @@ -163,6 +226,22 @@ def _load_encoded_videos( encoded_video_manifest_file_path: str, video_infos: Dict[str, VideoInfo], ): + """ + Load encoded videos and create a dictionary of EncodedVideo objects. + + This static method loads encoded video data from the specified encoded video manifest file and + organizes it into a dictionary of EncodedVideo objects. It ensures consistency and completeness of + data between video information and encoded video information. + + Args: + encoded_video_manifest_file_path (str): The file path to the manifest containing encoded video + information. + video_infos (Dict[str, VideoInfo]): A dictionary of video information keyed by video ID. + + Returns: + Dict[str, EncodedVideo]: A dictionary where the keys are video IDs, and the values are EncodedVideo + objects containing encoded video data. + """ encoded_video_infos: Dict[str, EncodedVideoInfo] = load_dataclass_dict_from_csv( encoded_video_manifest_file_path, EncodedVideoInfo, "video_id" ) @@ -181,6 +260,20 @@ def _frame_number_to_filepaths( video_frames: Dict[str, VideoFrameInfo], video_infos: Dict[str, VideoInfo], ) -> Optional[str]: + """ + Convert frame numbers to file paths. + + This static method generates file paths for frame numbers based on video frame information and video + information. + + Args: + video_id (str): The ID of the video. + video_frames (Dict[str, VideoFrameInfo]): A dictionary of video frame information keyed by video ID. + video_infos (Dict[str, VideoInfo]): A dictionary of video information keyed by video ID. + + Returns: + Optional[str]: A list of file paths for frames or None if frame numbers are invalid. + """ video_info = video_infos[video_id] video_frame_info = video_frames[video_info.video_id] @@ -216,6 +309,17 @@ def _remove_video_info_missing_or_incomplete_videos( video_data_infos: Dict[str, Union[VideoFrameInfo, EncodedVideoInfo]], video_infos: Dict[str, VideoInfo], ) -> None: + """ + Remove video information for missing or incomplete videos. + + This static method removes video information for videos that are missing corresponding video data + or do not have the correct number of frames. + + Args: + video_data_infos (Dict[str, Union[VideoFrameInfo, EncodedVideoInfo]]): A dictionary of video + data information keyed by video ID. + video_infos (Dict[str, VideoInfo]): A dictionary of video information keyed by video ID. 
+ """ # Avoid deletion keys from dict during iteration over keys video_ids = list(video_infos) for video_id in video_ids: @@ -248,14 +352,29 @@ def _remove_video_info_missing_or_incomplete_videos( def get_seconds_from_hms_time(time_str: str) -> float: """ - Get Seconds from timestamp of form 'HH:MM:SS'. + Convert a timestamp of the form 'HH:MM:SS' or 'HH:MM:SS.sss' to seconds. Args: - time_str (str) + time_str (str): A string representing a timestamp in the format 'HH:MM:SS' or 'HH:MM:SS.sss'. Returns: - float of seconds + float: The equivalent time in seconds. + Raises: + ValueError: If the provided string is not in a valid time format. + + This function parses the input 'time_str' as a timestamp in either 'HH:MM:SS' or 'HH:MM:SS.sss' format. + It then calculates and returns the equivalent time in seconds as a floating-point number. + + Example: + - Input: '01:23:45' + Output: 5025.0 seconds + - Input: '00:00:01.234' + Output: 1.234 seconds + + Note: + - The function supports both fractional seconds and integer seconds. + - If the input string is not in a valid time format, a ValueError is raised. """ for fmt in ("%H:%M:%S.%f", "%H:%M:%S"): try: @@ -271,18 +390,24 @@ def save_encoded_video_manifest( encoded_video_infos: Dict[str, EncodedVideoInfo], file_name: str = None ) -> str: """ - Saves the encoded video dictionary as a csv file that can be read for future usage. + Save a dictionary of encoded video information as a CSV file. - Args: - video_frames (Dict[str, EncodedVideoInfo]): - Dictionary mapping video_ids to metadata about the location of - their video data. + This function takes a dictionary of encoded video information, where keys are video IDs and values + are EncodedVideoInfo objects, and saves it as a CSV file. The CSV file can be used for future + reference and data retrieval. - file_name (str): - location to save file (will be automatically generated if None). + Args: + encoded_video_infos (Dict[str, EncodedVideoInfo]): + A dictionary mapping video IDs to metadata about the location of their encoded video data. + file_name (str, optional): + The file name or path where the CSV file will be saved. If not provided, a file name + will be automatically generated in the current working directory. Returns: - string of the filename where the video info is stored. + str: The filename where the encoded video information is stored. + + Note: + - The CSV file will have a header row with column names based on the EncodedVideoInfo data class. """ file_name = ( f"{os.getcwd()}/encoded_video_manifest.csv" if file_name is None else file_name @@ -295,18 +420,24 @@ def save_video_frame_info( video_frames: Dict[str, VideoFrameInfo], file_name: str = None ) -> str: """ - Saves the video frame dictionary as a csv file that can be read for future usage. + Save a dictionary of video frame information as a CSV file. + + This function takes a dictionary of video frame information, where keys are video IDs and values + are VideoFrameInfo objects, and saves it as a CSV file. The CSV file can be used for future + reference and data retrieval. Args: video_frames (Dict[str, VideoFrameInfo]): - Dictionary mapping video_ids to metadata about the location of - their video frame files. - - file_name (str): - location to save file (will be automatically generated if None). + A dictionary mapping video IDs to metadata about the location of their video frame files. + file_name (str, optional): + The file name or path where the CSV file will be saved. 
If not provided, a file name + will be automatically generated in the current working directory. Returns: - string of the filename where the video info is stored. + str: The filename where the video frame information is stored. + + Note: + - The CSV file will have a header row with column names based on the VideoFrameInfo data class. """ file_name = ( f"{os.getcwd()}/video_frame_metadata.csv" if file_name is None else file_name diff --git a/pytorchvideo/data/domsev.py b/pytorchvideo/data/domsev.py index 74f07490..2685b4e9 100644 --- a/pytorchvideo/data/domsev.py +++ b/pytorchvideo/data/domsev.py @@ -78,7 +78,16 @@ class LabelType(Enum): @dataclass class LabelData(DataclassFieldCaster): """ - Class representing a contiguous label for a video segment from the DoMSEV dataset. + Represents a continuous label for a video segment from the DoMSEV dataset. + + Attributes: + video_id (str): The unique identifier of the video. + start_time (float): The start time of the label, in seconds. + stop_time (float): The stop time of the label, in seconds. + start_frame (int): The 0-indexed ID of the start frame (inclusive). + stop_frame (int): The 0-indexed ID of the stop frame (inclusive). + label_id (int): The unique identifier of the label. + label_name (str): The name of the label. """ video_id: str @@ -96,7 +105,7 @@ def _seconds_to_frame_index( ) -> int: """ Converts a point in time (in seconds) within a video clip to its closest - frame indexed (rounding down), based on a specified frame rate. + frame index (rounding down), based on a specified frame rate. Args: time_in_seconds (float): The point in time within the video. @@ -105,7 +114,7 @@ def _seconds_to_frame_index( zero-indexed (if True) or one-indexed (if False). Returns: - (int) The index of the nearest frame (rounding down to the nearest integer). + int: The index of the nearest frame (rounding down to the nearest integer). """ frame_idx = math.floor(time_in_seconds * fps) if not zero_indexed: @@ -119,8 +128,14 @@ def _get_overlap_for_time_range_pair( """ Calculates the overlap between two time ranges, if one exists. + Args: + t1_start (float): The start time of the first time range. + t1_stop (float): The stop time of the first time range. + t2_start (float): The start time of the second time range. + t2_stop (float): The stop time of the second time range. + Returns: - (Optional[Tuple]) A tuple of if + Optional[Tuple[float, float]]: A tuple of if an overlap is found, or None otherwise. """ # Check if there is an overlap @@ -135,11 +150,30 @@ def _get_overlap_for_time_range_pair( class DomsevFrameDataset(torch.utils.data.Dataset): """ - Egocentric video classification frame-based dataset for + Dataset for frame-based egocentric video classification using the DoMSEV dataset. `DoMSEV `_ - This dataset handles the loading, decoding, and configurable sampling for - the image frames. + This dataset handles loading, decoding, and configurable sampling of image frames. + + Args: + video_data_manifest_file_path (str): Path to a JSON file outlining available video data + for associated videos. + video_info_file_path (str): Path or URI to a manifest with basic metadata for each video. + labels_file_path (str): Path or URI to a manifest with temporal annotations for each video. + transform (Optional[Callable[[Dict[str, Any]], Any]]): A callable for custom preprocessing + and augmentations to apply to the clips. Default is None. + multithreaded_io (bool): Control whether IO operations are performed across multiple threads. + Default is False. 
+ + Attributes: + _labels_per_frame (Dict[str, int]): A mapping of frame IDs to their corresponding label IDs. + _user_transform (Optional[Callable[[Dict[str, Any]], Any]]): User-defined transform function. + _transform (Callable[[Dict[str, Any]], Dict[str, Any]]): Default transformation function. + _frames (List[ImageFrameInfo]): List of image frame information. + + Methods: + __getitem__(self, index) -> Dict[str, Any]: Sample an image frame associated with the given index. + __len__(self) -> int: Get the number of frames in the dataset. """ def __init__( @@ -214,10 +248,11 @@ def _assign_labels_to_frames( video_labels: Dict[str, List[LabelData]], ): """ + Assign labels to frames based on temporal annotations. + Args: - frames_dict: The mapping of for all the frames - in the dataset. - video_labels: The list of temporal labels for each video + frames_dict (Dict[str, ImageFrameInfo]): Mapping of frame_id to ImageFrameInfo. + video_labels (Dict[str, List[LabelData]]): Temporal annotations for each video. Also unpacks one label per frame. Also converts them to class IDs and then a tensor. @@ -237,13 +272,13 @@ def _assign_labels_to_frames( def __getitem__(self, index) -> Dict[str, Any]: """ - Samples an image frame associated to the given index. + Get an image frame and associated information at the specified index. Args: - index (int): index for the image frame + index (int): Index for the image frame. Returns: - An image frame with the following format if transform is None. + Dict[str, Any]: Information about the image frame and its label. .. code-block:: text @@ -271,21 +306,22 @@ def __getitem__(self, index) -> Dict[str, Any]: def __len__(self) -> int: """ + Get the number of frames in the dataset. + Returns: - The number of frames in the dataset. + int: The number of frames. """ return len(self._frames) def _transform_frame(self, frame: Dict[str, Any]) -> Dict[str, Any]: """ - Transforms a given image frame, according to some pre-defined transforms - and an optional user transform function (self._user_transform). + Apply transformations to a given image frame. Args: - clip (Dict[str, Any]): The clip that will be transformed. + frame (Dict[str, Any]): Information about the image frame. Returns: - (Dict[str, Any]) The transformed clip. + Dict[str, Any]: Transformed information about the image frame. """ for key in frame: if frame[key] is None: @@ -494,16 +530,18 @@ def _transform_clip(self, clip: Dict[str, Any]) -> Dict[str, Any]: def _load_image_from_path(image_path: str, num_retries: int = 10) -> Image: """ - Loads the given image path using PathManager and decodes it as an RGB image. + Load an image from the given file path and decode it as an RGB image. Args: - image_path (str): the path to the image. - num_retries (int): number of times to retry image reading to handle transient error. + image_path (str): The path to the image file. + num_retries (int): The number of times to retry image reading to handle transient errors. Returns: - A PIL Image of the image RGB data with shape: - (channel, height, width). The frames are of type np.uint8 and - in the range [0 - 255]. Raises an exception if unable to load images. + Image: A PIL Image representing the loaded image in RGB format. + The image has the shape (channel, height, width) and pixel values in the range [0, 255]. + + Raises: + Exception: If unable to load the image after the specified number of retries. 
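+
+    Example (illustrative; the frame path below is an assumption):
+
+        >>> frame = _load_image_from_path("frames/frame_0000001.jpg", num_retries=3)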
""" if not _HAS_CV2: raise ImportError( diff --git a/pytorchvideo/data/ego4d/ego4d_dataset.py b/pytorchvideo/data/ego4d/ego4d_dataset.py index e41c27a7..0b1cfccf 100644 --- a/pytorchvideo/data/ego4d/ego4d_dataset.py +++ b/pytorchvideo/data/ego4d/ego4d_dataset.py @@ -39,7 +39,15 @@ class Ego4dImuData(Ego4dImuDataBase): """ - Wrapper for Ego4D IMU data loads, assuming one csv per video_uid at the provided path. + A wrapper for loading Ego4D IMU data, assuming one CSV file per video_uid at the provided path. + + Args: + imu_path (str): The base path to construct IMU CSV file paths. + Example format: /.csv + + This class is designed to facilitate the loading of IMU data for Ego4D videos. It assumes that + there is one CSV file per video_uid at the specified `imu_path`. The data is loaded and stored + for easy access. """ def __init__(self, imu_path: str) -> None: @@ -64,9 +72,27 @@ def __init__(self, imu_path: str) -> None: self.imu_video_data: Optional[Tuple[np.ndarray, np.ndarray, int]] = None def has_imu(self, video_uid: str) -> bool: + """ + Check if IMU data is available for a given video UID. + + Args: + video_uid (str): The video UID to check. + + Returns: + bool: True if IMU data is available, otherwise False. + """ return video_uid in self.IMU_by_video_uid def _load_csv(self, csv_path: str) -> List[Dict[str, Any]]: + """ + Load data from a CSV file. + + Args: + csv_path (str): The path to the CSV file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries representing the CSV data. + """ with g_pathmgr.open(csv_path, "r") as f: reader = csv.DictReader(f) data = [] @@ -75,6 +101,16 @@ def _load_csv(self, csv_path: str) -> List[Dict[str, Any]]: return data def _load_imu(self, video_uid: str) -> Tuple[np.ndarray, np.ndarray, int]: + """ + Load IMU data for a given video UID. + + Args: + video_uid (str): The video UID for which to load IMU data. + + Returns: + Tuple[np.ndarray, np.ndarray, int]: A tuple containing the IMU signal, timestamps, + and sampling rate. + """ file_path = os.path.join(self.path_imu, video_uid) + ".csv" data_csv = self._load_csv(file_path) data_IMU = defaultdict(list) @@ -113,6 +149,20 @@ def _get_imu_window( timestamps: np.ndarray, sampling_rate: float, ) -> Dict[str, Any]: + """ + Retrieve IMU data for a specified time window. + + Args: + window_start (float): The start time of the window in seconds. + window_end (float): The end time of the window in seconds. + signal (np.ndarray): The IMU signal. + timestamps (np.ndarray): The timestamps corresponding to the IMU data. + sampling_rate (float): The sampling rate of the IMU data. + + Returns: + Dict[str, Any]: A dictionary containing the timestamp, signal, and sampling rate + for the specified time window. + """ start_id = bisect_left(timestamps, window_start * 1000) end_id = bisect_left(timestamps, window_end * 1000) if end_id == len(timestamps): @@ -126,12 +176,34 @@ def _get_imu_window( return sample_dict def get_imu(self, video_uid: str) -> Tuple[np.ndarray, np.ndarray, int]: + """ + Get the IMU data for a specified video UID. + + Args: + video_uid (str): The video UID for which to retrieve IMU data. + + Returns: + Tuple[np.ndarray, np.ndarray, int]: A tuple containing the IMU signal, timestamps, + and sampling rate for the specified video UID. + """ # Caching/etc? return self._load_imu(video_uid) def get_imu_sample( self, video_uid: str, video_start: float, video_end: float ) -> Dict[str, Any]: + """ + Get an IMU sample for a specified time window within a video. 
+ + Args: + video_uid (str): The video UID for which to retrieve the IMU sample. + video_start (float): The start time of the video segment in seconds. + video_end (float): The end time of the video segment in seconds. + + Returns: + Dict[str, Any]: A dictionary containing the timestamp, signal, and sampling rate + for the specified time window. + """ # Assumes video clips are loaded sequentially, will lazy load imu if not self.imu_video_uid or video_uid != self.imu_video_uid: self.imu_video_uid = video_uid @@ -432,6 +504,19 @@ def __init__( self.video_path_handler = video_path_handler def check_IMU(self, input_dict: Dict[str, Any]) -> bool: + """ + Checks if the IMU data in the input dictionary is valid. + + Args: + input_dict (Dict[str, Any]): A dictionary containing IMU data and other information. + + Returns: + bool: True if the IMU data is problematic, False otherwise. + + This function checks several conditions to determine if the IMU data is problematic. + If any of the conditions are met, it logs a warning and returns True. Otherwise, it + returns False. + """ if ( len(input_dict["imu"]["signal"].shape) != 2 or input_dict["imu"]["signal"].shape[0] == 0 @@ -444,6 +529,20 @@ def check_IMU(self, input_dict: Dict[str, Any]) -> bool: return False def _transform_mm(self, sample_dict: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """ + Transforms a sample dictionary for model processing. + + Args: + sample_dict (Dict[str, Any]): A dictionary containing sample data. + + Returns: + Optional[Dict[str, Any]]: A transformed dictionary or None if transformation fails. + + This function transforms a sample dictionary for model processing. It checks and + manipulates various aspects of the dictionary, including video, audio, labels, + and IMU data, if available. The transformed dictionary is returned, or None if + transformation fails. + """ log.info("_transform_mm") with profiler.record_function("_transform_mm"): video_uid = sample_dict["video_uid"] @@ -509,9 +608,11 @@ def _transform_mm(self, sample_dict: Dict[str, Any]) -> Optional[Dict[str, Any]] # pyre-ignore def _video_transform(self): """ - This function contains example transforms using both PyTorchVideo and - TorchVision in the same callable. For 'train' model, we use augmentations (prepended - with 'Random'), for 'val' we use the respective deterministic function + Defines video transformations for data augmentation. + + This function contains example transforms using both PyTorchVideo and TorchVision + in the same callable. For 'train' mode, it applies augmentations (prepended + with 'Random'), and for 'val', it uses deterministic functions. """ assert ( @@ -544,6 +645,19 @@ def _video_transform(self): return Compose([video_transforms]) def signal_transform(self, type: str = "spectrogram", sample_rate: int = 48000): + """ + Defines signal transformations for audio data. + + Args: + type (str): The type of signal transformation to apply. + sample_rate (int): The sample rate of the audio data. + + Returns: + transform: A torchaudio transform for the specified type. + + This function defines signal transformations for audio data, including spectrogram, + mel spectrogram, and MFCC transformations, based on the specified type. + """ if type == "spectrogram": n_fft = 1024 win_length = None @@ -599,6 +713,20 @@ def signal_transform(self, type: str = "spectrogram", sample_rate: int = 48000): return transform def _preproc_audio(self, audio, audio_fps) -> Dict[str, Any]: + """ + Preprocesses audio data. + + Args: + audio: The audio data. 
+ audio_fps: The audio sample rate. + + Returns: + Dict[str, Any]: A dictionary containing preprocessed audio data. + + This function preprocesses audio data, converting stereo to mono and applying + signal transformations such as spectrogram, mel spectrogram, or MFCC based on the + specified audio transformation type. + """ # convert stero to mono # https://github.com/pytorch/audio/issues/363 waveform_mono = torch.mean(audio, dim=0, keepdim=True) @@ -612,6 +740,18 @@ def _preproc_audio(self, audio, audio_fps) -> Dict[str, Any]: } def convert_one_hot(self, label_list: List[str]) -> List[int]: + """ + Converts a list of labels to one-hot encoding. + + Args: + label_list (List[str]): A list of labels. + + Returns: + List[int]: A list representing one-hot encoding for the labels. + + This function converts a list of labels to one-hot encoding based on a predefined + label-to-ID mapping. + """ labels = [x for x in label_list if x in self.label_name_id_map.keys()] assert len(labels) == len( label_list diff --git a/pytorchvideo/data/ego4d/utils.py b/pytorchvideo/data/ego4d/utils.py index 186004fd..6d62a36a 100644 --- a/pytorchvideo/data/ego4d/utils.py +++ b/pytorchvideo/data/ego4d/utils.py @@ -18,7 +18,25 @@ def check_window_len( s_time: float, e_time: float, w_len: float, video_dur: float ) -> Tuple[float, float]: """ - Constrain/slide the give time window to `w_len` size and the video/clip length. + Constrain or slide the given time window to match a specified length `w_len` while + considering the video or clip duration. + + Args: + s_time (float): The start time of the original time window. + e_time (float): The end time of the original time window. + w_len (float): The desired length of the time window. + video_dur (float): The duration of the video or clip. + + Returns: + Tuple[float, float]: A tuple containing the adjusted start and end times. + + This function adjusts the time window defined by `s_time` and `e_time` to match + a specified length `w_len`. If the time window is larger or smaller than `w_len`, + it is adjusted by equally extending or trimming the interior. If the adjusted + time window exceeds the duration of the video or clip, it is further adjusted to + stay within the video duration. + + Note: The function ensures that the adjusted time window has a length close to `w_len`. """ # adjust to match w_len interval = e_time - s_time @@ -50,11 +68,18 @@ def check_window_len( # TODO: Move to FixedClipSampler? class MomentsClipSampler(ClipSampler): """ - ClipSampler for Ego4d moments. Will return a fixed `window_sec` window - around the given annotation, shifting where relevant to account for the end - of the clip/video. + ClipSampler for Ego4d moments. This sampler returns a fixed-duration `window_sec` + window around a given annotation, adjusting for the end of the clip/video if necessary. + + The `clip_start` and `clip_end` fields are added to the annotation dictionary to + facilitate future lookups. - clip_start/clip_end is added to the annotation dict to facilitate future lookups. + Args: + window_sec (float): The duration (in seconds) of the fixed window to sample. + + This ClipSampler is designed for Ego4d moments and ensures that clips are sampled + with a fixed duration specified by `window_sec`. It adjusts the window's position + if needed to account for the end of the clip or video. 
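+
+    Example (construction only; the annotation fields consumed by ``__call__`` are
+    dataset-specific and omitted here):
+
+        >>> sampler = MomentsClipSampler(window_sec=5.0)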
""" def __init__(self, window_sec: float = 0) -> None: @@ -93,6 +118,17 @@ def __call__( def get_label_id_map(label_id_map_path: str) -> Dict[str, int]: + """ + Reads a label-to-ID mapping from a JSON file. + + Args: + label_id_map_path (str): The path to the label ID mapping JSON file. + + Returns: + Dict[str, int]: A dictionary mapping label names to their corresponding IDs. + + This function reads a JSON file containing label-to-ID mapping and returns it as a dictionary. + """ label_name_id_map: Dict[str, int] try: @@ -108,17 +144,52 @@ def get_label_id_map(label_id_map_path: str) -> Dict[str, int]: class Ego4dImuDataBase(ABC): """ Base class placeholder for Ego4d IMU data. + + This is a base class for handling Ego4d IMU data. It defines the required interface for + checking if IMU data is available for a video and retrieving IMU samples. """ def __init__(self, basepath: str): + """ + Initializes an instance of Ego4dImuDataBase. + + Args: + basepath (str): The base path for Ego4d IMU data. + """ self.basepath = basepath @abstractmethod def has_imu(self, video_uid: str) -> bool: + """ + Checks if IMU data is available for a video. + + Args: + video_uid (str): The unique identifier of the video. + + Returns: + bool: True if IMU data is available, False otherwise. + + This method should be implemented to check if IMU data exists for a specific video + identified by its unique ID. + """ pass @abstractmethod def get_imu_sample( self, video_uid: str, video_start: float, video_end: float ) -> Dict[str, Any]: + """ + Retrieves an IMU sample for a video segment. + + Args: + video_uid (str): The unique identifier of the video. + video_start (float): The start time of the video segment. + video_end (float): The end time of the video segment. + + Returns: + Dict[str, Any]: A dictionary containing IMU data. + + This method should be implemented to retrieve IMU data for a specific video segment + identified by its unique ID and time range. + """ pass diff --git a/pytorchvideo/data/encoded_video.py b/pytorchvideo/data/encoded_video.py index 227227ad..7b3566b1 100644 --- a/pytorchvideo/data/encoded_video.py +++ b/pytorchvideo/data/encoded_video.py @@ -16,10 +16,13 @@ def select_video_class(decoder: str) -> Video: """ - Select the class for accessing clips based on provided decoder string + Select the class for accessing clips based on the provided decoder string. Args: - decoder (str): Defines what type of decoder used to decode a video. + decoder (str): Defines what type of decoder is used to decode a video. + + Returns: + Video: An instance of the selected video decoding class. """ if DecoderType(decoder) == DecoderType.PYAV: from .encoded_video_pyav import EncodedVideoPyAV @@ -55,11 +58,14 @@ def from_path( **other_args: Dict[str, Any], ): """ - Fetches the given video path using PathManager (allowing remote uris to be + Fetches the given video path using PathManager (allowing remote URIs to be fetched) and constructs the EncodedVideo object. Args: - file_path (str): a PathManager file-path. + file_path (str): A PathManager file path. + + Returns: + EncodedVideo: An instance of the EncodedVideo class. """ # We read the file with PathManager so that we can read from remote uris. 
with g_pathmgr.open(file_path, "rb") as fh: diff --git a/pytorchvideo/data/encoded_video_decord.py b/pytorchvideo/data/encoded_video_decord.py index 5ae85dc0..65e713fc 100644 --- a/pytorchvideo/data/encoded_video_decord.py +++ b/pytorchvideo/data/encoded_video_decord.py @@ -119,20 +119,28 @@ def __init__( @property def name(self) -> Optional[str]: """ + Get the name of the stored video if set. + Returns: - name: the name of the stored video if set. + name (Optional[str]): The name of the video. """ return self._video_name @property def duration(self) -> float: """ + Get the video's duration/end-time in seconds. + Returns: - duration: the video's duration/end-time in seconds. + duration (float): The video's duration. """ return self._duration def close(self): + """ + Close the video reader. + """ + if self._av_reader is not None: del self._av_reader self._av_reader = None diff --git a/pytorchvideo/data/encoded_video_pyav.py b/pytorchvideo/data/encoded_video_pyav.py index 8e952381..036275ca 100644 --- a/pytorchvideo/data/encoded_video_pyav.py +++ b/pytorchvideo/data/encoded_video_pyav.py @@ -326,18 +326,20 @@ def _pyav_decode_stream( perform_seek: bool = True, ) -> Tuple[List, float]: """ - Decode the video with PyAV decoder. + Decode video frames from a PyAV container using a specified stream. + Args: - container (container): PyAV container. - start_pts (int): the starting Presentation TimeStamp to fetch the - video frames. - end_pts (int): the ending Presentation TimeStamp of the decoded frames. - stream (stream): PyAV stream. - stream_name (dict): a dictionary of streams. For example, {"video": 0} - means video stream at stream index 0. + container (av.container.input.InputContainer): PyAV container containing the video. + start_pts (int): The starting Presentation TimeStamp (PTS) to fetch video frames. + end_pts (int): The ending Presentation TimeStamp (PTS) of the decoded frames. + stream (av.video.stream.VideoStream): PyAV video stream to decode. + stream_name (dict): A dictionary of streams, e.g., {"video": 0} for the video stream at index 0. + buffer_size (int): Size of the frame buffer (unused in this function). + perform_seek (bool): Whether to perform seeking in the stream (may affect performance). + Returns: - result (list): list of decoded frames. - max_pts (int): max Presentation TimeStamp of the video sequence. + result (List[av.video.frame.VideoFrame]): List of decoded video frames. + max_pts (float): The maximum Presentation TimeStamp (PTS) of the video sequence. """ # Seeking in the stream is imprecise. Thus, seek to an earlier pts by a diff --git a/pytorchvideo/data/encoded_video_torchvision.py b/pytorchvideo/data/encoded_video_torchvision.py index eee8f17a..2114ef55 100644 --- a/pytorchvideo/data/encoded_video_torchvision.py +++ b/pytorchvideo/data/encoded_video_torchvision.py @@ -34,6 +34,15 @@ def __init__( decode_video: bool = True, decode_audio: bool = True, ) -> None: + """ + Initialize an EncodedVideoTorchVision object. + + Args: + file (BinaryIO): A file-like object containing the encoded video. + video_name (str, optional): An optional name assigned to the video. + decode_video (bool): Whether to decode the video. + decode_audio (bool): Whether to decode the audio. + """ if not decode_video: raise NotImplementedError() @@ -77,20 +86,27 @@ def __init__( @property def name(self) -> Optional[str]: """ + Get the name of the stored video. + Returns: - name: the name of the stored video if set. + str or None: The video's name if set, otherwise None. 
""" return self._video_name @property def duration(self) -> float: """ + Get the video's duration in seconds. + Returns: - duration: the video's duration/end-time in seconds. + float: The video's duration in seconds. """ return self._duration def close(self): + """ + Close the video (not implemented). + """ pass def get_clip( @@ -184,7 +200,14 @@ def _torch_vision_decode_video( self, start_pts: int = 0, end_pts: int = -1 ) -> float: """ - Decode the video in the PTS range [start_pts, end_pts] + Decode the video in the specified PTS range. + + Args: + start_pts (int): The starting Presentation TimeStamp (PTS) to decode. + end_pts (int): The ending Presentation TimeStamp (PTS) to decode. + + Returns: + tuple: A tuple containing video and audio data as well as other information. """ video_and_pts = None audio_and_pts = None diff --git a/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py b/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py index 6077517b..8af62090 100644 --- a/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py +++ b/pytorchvideo/data/epic_kitchen/epic_kitchen_dataset.py @@ -22,33 +22,54 @@ @dataclass class ActionData(DataclassFieldCaster): """ - Class representing an action from the Epic Kitchen dataset. + Represents an action from the Epic Kitchen dataset. + + This class encapsulates information about actions performed in the dataset, + including participant ID, video ID, narration, timestamps, verb, noun, and more. + + Attributes: + participant_id (str): The unique identifier of the participant. + video_id (str): The identifier of the video containing the action. + narration (str): The textual narration of the action. + start_timestamp (str): The start timestamp of the action in HH:MM:SS format. + stop_timestamp (str): The stop timestamp of the action in HH:MM:SS format. + start_frame (int): The starting frame of the action. + stop_frame (int): The ending frame of the action. + verb (str): The verb describing the action. + verb_class (int): The class label for the verb. + noun (str): The noun associated with the action. + noun_class (int): The class label for the noun. + all_nouns (list): A list of all nouns related to the action. + all_noun_classes (list): A list of class labels for all related nouns. + + Properties: + start_time (float): The start time of the action in seconds. + stop_time (float): The stop time of the action in seconds. + + Methods: + None + + Note: + This class is designed to represent actions from the Epic Kitchen dataset. """ - - participant_id: str - video_id: str - narration: str - start_timestamp: str - stop_timestamp: str - start_frame: int - stop_frame: int - verb: str - verb_class: int - noun: str - noun_class: int - all_nouns: list = DataclassFieldCaster.complex_initialized_dataclass_field( - ast.literal_eval - ) - all_noun_classes: list = DataclassFieldCaster.complex_initialized_dataclass_field( - ast.literal_eval - ) - @property def start_time(self) -> float: + """ + Get the start time of the action in seconds. + + Returns: + float: The start time in seconds. + """ return get_seconds_from_hms_time(self.start_timestamp) @property def stop_time(self) -> float: + """ + Get the stop time of the action in seconds. + + Returns: + float: The stop time in seconds. 
+ """ return get_seconds_from_hms_time(self.stop_timestamp) diff --git a/pytorchvideo/data/epic_kitchen/utils.py b/pytorchvideo/data/epic_kitchen/utils.py index 3dedff09..55e5984f 100644 --- a/pytorchvideo/data/epic_kitchen/utils.py +++ b/pytorchvideo/data/epic_kitchen/utils.py @@ -11,22 +11,32 @@ def build_frame_manifest_from_flat_directory( data_directory_path: str, multithreaded: bool ) -> Dict[str, VideoFrameInfo]: """ + Builds a manifest of video frame information from a flat directory structure. + Args: - data_directory_path (str): Path or URI to EpicKitchenDataset data. - Data at this path must be a folder of structure: - { - "{video_id}": [ - "frame_{frame_number}.{file_extension}", - "frame_{frame_number}.{file_extension}", - "frame_{frame_number}.{file_extension}", - ...] - ...} + data_directory_path (str): + The path or URI to the EpicKitchenDataset data. Data at this path must be organized as follows: + { + "{video_id}": [ + "frame_{frame_number}.{file_extension}", + "frame_{frame_number}.{file_extension}", + "frame_{frame_number}.{file_extension}", + ... + ] + } + multithreaded (bool): - controls whether io operations are performed across multiple threads. + Controls whether I/O operations are performed across multiple threads. Returns: - Dictionary mapping video_id of available videos to the locations of their + Dict[str, VideoFrameInfo]: A dictionary mapping the video_id of available videos to the locations of their underlying frame files. + + This function iterates through the provided data directory, identifies video frames, and builds a manifest + containing information about each video's frame files. It returns a dictionary where each video_id is associated + with a VideoFrameInfo object that includes details about the frames for that video. + + Note: This function assumes a specific directory structure and naming conventions for frame files. """ video_frames = {} @@ -87,23 +97,33 @@ def build_frame_manifest_from_nested_directory( data_directory_path: str, multithreaded: bool ) -> Dict[str, VideoFrameInfo]: """ - Args: - data_directory_path (str): Path or URI to EpicKitchenDataset data. - If this dataset is to load from the frame-based dataset: - Data at this path must be a folder of structure: - { - "{participant_id}" : [ - "{participant_id}_{participant_video_id}_{frame_number}.{file_extension}", + Builds a manifest of video frame information from a nested directory structure. - ...], - ...} + Args: + data_directory_path (str): + The path or URI to the EpicKitchenDataset data. If this dataset is intended to load from a + frame-based dataset, the data at this path must be organized in a specific nested structure: + + { + "{participant_id}": [ + "{participant_id}_{participant_video_id}_{frame_number}.{file_extension}", + ... + ], + ... + } multithreaded (bool): - controls whether io operations are performed across multiple threads. + Controls whether I/O operations are performed across multiple threads. - Returns: - Dictionary mapping video_id of available videos to the locations of their - underlying frame files. + Returns: + Dict[str, VideoFrameInfo]: + A dictionary mapping the video_id of available videos to the locations of their underlying frame files. + + This function iterates through the provided data directory, identifies video frames, and builds a manifest + containing information about each video's frame files. It returns a dictionary where each video_id is associated + with a VideoFrameInfo object that includes details about the frames for that video. 
+ + Note: This function assumes a specific directory structure and naming conventions for frame files. """ participant_ids = g_pathmgr.ls(str(data_directory_path)) @@ -174,16 +194,20 @@ def build_encoded_manifest_from_nested_directory( data_directory_path: str, ) -> Dict[str, EncodedVideoInfo]: """ - Creates a dictionary from video_id to EncodedVideoInfo for - encoded videos in the given directory. + Creates a dictionary mapping video_id to EncodedVideoInfo for encoded videos in the given directory. Args: - data_directory_path (str): The folder to ls to find encoded - video files. + data_directory_path (str): + The folder to list to find encoded video files. Returns: - Dict[str, EncodedVideoInfo] mapping video_id to EncodedVideoInfo - for each file in 'data_directory_path' + Dict[str, EncodedVideoInfo]: + A dictionary mapping video_id to EncodedVideoInfo for each file in 'data_directory_path'. + + This function scans the provided data directory to identify encoded video files and creates a dictionary + where each video_id is associated with an EncodedVideoInfo object containing information about the video. + + Note: This function assumes a specific naming convention for video files and the structure of the data directory. """ encoded_video_infos = {} for participant_id in g_pathmgr.ls(data_directory_path): diff --git a/pytorchvideo/data/epic_kitchen_forecasting.py b/pytorchvideo/data/epic_kitchen_forecasting.py index 8a6ad5e6..ffa51f7a 100644 --- a/pytorchvideo/data/epic_kitchen_forecasting.py +++ b/pytorchvideo/data/epic_kitchen_forecasting.py @@ -147,19 +147,34 @@ def _transform_generator( num_input_clips: int, ) -> Callable[[Dict[str, Any]], Dict[str, Any]]: """ + Generate a custom transform function for video clips. + Args: - transform (Callable[[Dict[str, Any]], Dict[str, Any]]): A function that performs - any operation on a clip before it is returned in the default transform function. - num_forecast_actions: (int) The number of actions to be included in the - action vector. - frames_per_clip (int): The number of frames per clip to sample. - num_input_clips (int): The number of subclips to be included in the video data. + transform (Callable[[Dict[str, Any]], Dict[str, Any]]): + A function that performs any operation on a clip before it is returned + in the default transform function. + num_forecast_actions (int): + The number of actions to be included in the action vector. + frames_per_clip (int): + The number of frames per clip to sample. + num_input_clips (int): + The number of subclips to be included in the video data. Returns: - A function that performs any operation on a clip and returns the transformed clip. + Callable[[Dict[str, Any]], Dict[str, Any]]: + A function that performs any operation on a clip and returns the transformed clip. """ def transform_clip(clip: Dict[str, Any]) -> Dict[str, Any]: + """ + Transform a video clip according to the specified parameters. + + Args: + clip (Dict[str, Any]): The clip to be transformed. + + Returns: + Dict[str, Any]: The transformed clip. + """ assert all( clip["actions"][i].start_time <= clip["actions"][i + 1].start_time for i in range(len(clip["actions"]) - 1) @@ -199,20 +214,35 @@ def _frame_filter_generator( num_input_clips: int, ) -> Callable[[List[int]], List[int]]: """ + Generate a frame filter function for subclip sampling. + Args: - frames_per_clip (int): The number of frames per clip to sample. - seconds_per_clip (float): The length of each sampled subclip in seconds. 
- clip_time_stride (float): The time difference in seconds between the start of - each input subclip. - num_input_clips (int): The number of subclips to be included in the video data. + frames_per_clip (int): + The number of frames per clip to sample. + seconds_per_clip (float): + The length of each sampled subclip in seconds. + clip_time_stride (float): + The time difference in seconds between the start of each input subclip. + num_input_clips (int): + The number of subclips to be included in the video data. Returns: - A function that takes in a list of frame indicies and outputs a subsampled list. + Callable[[List[int]], List[int]]: + A function that takes in a list of frame indices and outputs a subsampled list. """ time_window_length = seconds_per_clip + (num_input_clips - 1) * clip_time_stride desired_frames_per_second = frames_per_clip / seconds_per_clip def frame_filter(frame_indices: List[int]) -> List[int]: + """ + Filter a list of frame indices to subsample subclips. + + Args: + frame_indices (List[int]): The list of frame indices to filter. + + Returns: + List[int]: The subsampled list of frame indices. + """ num_available_frames_for_all_clips = len(frame_indices) available_frames_per_second = ( num_available_frames_for_all_clips / time_window_length @@ -242,19 +272,24 @@ def _define_clip_structure_generator( num_forecast_actions: int, ) -> Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]: """ + Generate a clip structure defining function based on sampling strategy. + Args: - clip_sampling (ClipSampling): - The type of sampling to perform to perform on the videos of the dataset. - seconds_per_clip (float): The length of each sampled clip in seconds. - clip_time_stride: The time difference in seconds between the start of - each input subclip. - num_input_clips (int): The number of subclips to be included in the video data. - num_forecast_actions (int): The number of actions to be included in the - action vector. + clip_sampling (str): + The type of sampling to perform on the videos of the dataset. + seconds_per_clip (float): + The length of each sampled clip in seconds. + clip_time_stride (float): + The time difference in seconds between the start of each input subclip. + num_input_clips (int): + The number of subclips to be included in the video data. + num_forecast_actions (int): + The number of actions to be included in the action vector. Returns: - A function that takes a dictionary of videos and outputs a list of sampled - clips. + Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]: + A function that takes a dictionary of videos and their associated actions, + and outputs a list of sampled video clip information. """ # TODO(T77683480) if not clip_sampling == ClipSampling.Random: @@ -268,6 +303,19 @@ def _define_clip_structure_generator( def define_clip_structure( videos: Dict[str, Video], video_actions: Dict[str, List[ActionData]] ) -> List[VideoClipInfo]: + """ + Define the structure of video clips based on specified parameters. + + Args: + videos (Dict[str, Video]): + A dictionary of videos indexed by video ID. + video_actions (Dict[str, List[ActionData]]): + A dictionary of video actions indexed by video ID. + + Returns: + List[VideoClipInfo]: + A list of VideoClipInfo objects representing the sampled video clips. 
+ """ candidate_sample_clips = [] for video_id, actions in video_actions.items(): for i, action in enumerate(actions[: (-1 * num_forecast_actions)]): diff --git a/pytorchvideo/data/epic_kitchen_recognition.py b/pytorchvideo/data/epic_kitchen_recognition.py index 8a6f688e..157898c7 100644 --- a/pytorchvideo/data/epic_kitchen_recognition.py +++ b/pytorchvideo/data/epic_kitchen_recognition.py @@ -179,14 +179,18 @@ def _define_clip_structure_generator( seconds_per_clip: float, clip_sampling: ClipSampling ) -> Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]: """ + Generate a clip structure defining function based on specified parameters. + Args: - seconds_per_clip (float): The length of each sampled clip in seconds. + seconds_per_clip (float): + The length of each sampled clip in seconds. clip_sampling (ClipSampling): - The type of sampling to perform to perform on the videos of the dataset. + The type of sampling to perform on the videos of the dataset. Returns: - A function that takes a dictionary of videos and a dictionary of the actions - for each video and outputs a list of sampled clips. + Callable[[Dict[str, Video], Dict[str, List[ActionData]]], List[VideoClipInfo]]: + A function that takes a dictionary of videos and their associated actions, + and outputs a list of sampled video clip information. """ if not clip_sampling == ClipSampling.RandomOffsetUniform: raise NotImplementedError( @@ -197,6 +201,19 @@ def _define_clip_structure_generator( def define_clip_structure( videos: Dict[str, Video], actions: Dict[str, List[ActionData]] ) -> List[VideoClipInfo]: + """ + Define the structure of video clips based on specified parameters. + + Args: + videos (Dict[str, Video]): + A dictionary of videos indexed by video ID. + actions (Dict[str, List[ActionData]]): + A dictionary of video actions indexed by video ID. + + Returns: + List[VideoClipInfo]: + A list of VideoClipInfo objects representing the sampled video clips. + """ clips = [] for video_id, video in videos.items(): offset = random.random() * seconds_per_clip diff --git a/pytorchvideo/data/frame_video.py b/pytorchvideo/data/frame_video.py index d3aacf2f..a5df8362 100644 --- a/pytorchvideo/data/frame_video.py +++ b/pytorchvideo/data/frame_video.py @@ -32,9 +32,9 @@ class FrameVideo(Video): """ - FrameVideo is an abstractions for accessing clips based on their start and end + FrameVideo is an abstraction for accessing clips based on their start and end time for a video where each frame is stored as an image. PathManager is used for - frame image reading, allowing non-local uri's to be used. + frame image reading, allowing non-local URIs to be used. """ def __init__( @@ -46,15 +46,22 @@ def __init__( multithreaded_io: bool = False, ) -> None: """ + Initialize a FrameVideo object. + Args: - duration (float): the duration of the video in seconds. - fps (float): the target fps for the video. This is needed to link the frames - to a second timestamp in the video. - video_frame_to_path_fn (Callable[[int], str]): a function that maps from a frame - index integer to the file path where the frame is located. - video_frame_paths (List[str]): Dictionary of frame paths for each index of a video. - multithreaded_io (bool): controls whether parllelizable io operations are - performed across multiple threads. + duration (float): + The duration of the video in seconds. + fps (float): + The target FPS for the video. This is needed to link the frames to a second + timestamp in the video. 
+ video_frame_to_path_fn (Callable[[int], str], optional): + A function that maps from a frame index integer to the file path where the + frame is located. + video_frame_paths (List[str], optional): + List of frame paths for each index of a video. + multithreaded_io (bool, optional): + Controls whether parallelizable IO operations are performed across multiple + threads. """ if not _HAS_CV2: raise ImportError( @@ -86,15 +93,24 @@ def from_directory( path_order_cache: Optional[Dict[str, List[str]]] = None, ): """ + Create a FrameVideo object from a directory containing frame images. + Args: - path (str): path to frame video directory. - fps (float): the target fps for the video. This is needed to link the frames - to a second timestamp in the video. - multithreaded_io (bool): controls whether parllelizable io operations are - performed across multiple threads. - path_order_cache (dict): An optional mapping from directory-path to list - of frames in the directory in numerical order. Used for speedup by - caching the frame paths. + path (str): + Path to the frame video directory. + fps (float, optional): + The target FPS for the video. This is needed to link the frames to a second + timestamp in the video. + multithreaded_io (bool, optional): + Controls whether parallelizable IO operations are performed across multiple + threads. + path_order_cache (dict, optional): + An optional mapping from directory path to list of frames in the directory + in numerical order. Used for speedup by caching the frame paths. + + Returns: + FrameVideo: + A FrameVideo object created from the provided frame directory. """ if path_order_cache is not None and path in path_order_cache: return cls.from_frame_paths(path_order_cache[path], fps, multithreaded_io) @@ -119,12 +135,21 @@ def from_frame_paths( multithreaded_io: bool = False, ): """ + Create a FrameVideo object from a list of frame image paths. + Args: - video_frame_paths (List[str]): a list of paths to each frames in the video. - fps (float): the target fps for the video. This is needed to link the frames - to a second timestamp in the video. - multithreaded_io (bool): controls whether parllelizable io operations are - performed across multiple threads. + video_frame_paths (List[str]): + A list of paths to each frame in the video. + fps (float, optional): + The target FPS for the video. This is needed to link the frames to a second + timestamp in the video. + multithreaded_io (bool, optional): + Controls whether parallelizable IO operations are performed across multiple + threads. + + Returns: + FrameVideo: + A FrameVideo object created from the provided frame image paths. """ assert len(video_frame_paths) != 0, "video_frame_paths is empty" return cls( @@ -136,13 +161,21 @@ def from_frame_paths( @property def name(self) -> float: + """ + Returns the name of the FrameVideo. + + Returns: + str: The name of the FrameVideo. + """ return self._name @property def duration(self) -> float: """ + Returns the duration of the FrameVideo. + Returns: - duration: the video's duration/end-time in seconds. + float: The duration of the FrameVideo in seconds. """ return self._duration @@ -158,31 +191,31 @@ def get_clip( """ Retrieves frames from the stored video at the specified start and end times in seconds (the video always starts at 0 seconds). Returned frames will be - in [start_sec, end_sec). Given that PathManager may - be fetching the frames from network storage, to handle transient errors, frame - reading is retried N times. 
Note that as end_sec is exclusive, so you may need - to use `get_clip(start_sec, duration + EPS)` to get the last frame. + in [start_sec, end_sec). Given that PathManager may be fetching the frames + from network storage, to handle transient errors, frame reading is retried N times. + Note that as end_sec is exclusive, so you may need to use `get_clip(start_sec, duration + EPS)` + to get the last frame. Args: - start_sec (float): the clip start time in seconds - end_sec (float): the clip end time in seconds - frame_filter (Optional[Callable[List[int], List[int]]]): - function to subsample frames in a clip before loading. - If None, no subsampling is peformed. - Returns: - clip_frames: A tensor of the clip's RGB frames with shape: - (channel, time, height, width). The frames are of type torch.float32 and - in the range [0 - 255]. Raises an exception if unable to load images. - - clip_data: - "video": A tensor of the clip's RGB frames with shape: - (channel, time, height, width). The frames are of type torch.float32 and - in the range [0 - 255]. Raises an exception if unable to load images. - - "frame_indices": A list of indices for each frame relative to all frames in the - video. + start_sec (float): + The clip start time in seconds. + end_sec (float): + The clip end time in seconds. + frame_filter (Optional[Callable[List[int], List[int]]], optional): + Function to subsample frames in a clip before loading. + If None, no subsampling is performed. - Returns None if no frames are found. + Returns: + Dict[str, Optional[torch.Tensor]]: + A dictionary containing the following keys: + - "video": A tensor of the clip's RGB frames with shape: + (channel, time, height, width). The frames are of type torch.float32 and + in the range [0 - 255]. + - "frame_indices": A list of indices for each frame relative to all frames in the + video. + - "audio": None (audio is not supported in FrameVideo). + + Returns None if no frames are found. """ if start_sec < 0 or start_sec > self._duration: logger.warning( @@ -224,16 +257,24 @@ def _load_images_with_retries( image_paths: List[str], num_retries: int = 10, multithreaded: bool = True ) -> torch.Tensor: """ - Loads the given image paths using PathManager, decodes them as RGB images and - returns them as a stacked tensors. + Loads the given image paths using PathManager, decodes them as RGB images, and + returns them as a stacked tensor. + Args: - image_paths (List[str]): a list of paths to images. - num_retries (int): number of times to retry image reading to handle transient error. - multithreaded (bool): if images are fetched via multiple threads in parallel. + image_paths (List[str]): + A list of paths to images. + num_retries (int, optional): + Number of times to retry image reading to handle transient errors. + multithreaded (bool, optional): + If images are fetched via multiple threads in parallel. + Returns: - A tensor of the clip's RGB frames with shape: - (time, height, width, channel). The frames are of type torch.uint8 and - in the range [0 - 255]. Raises an exception if unable to load images. + torch.Tensor: + A tensor of the clip's RGB frames with shape: (time, height, width, channel). + The frames are of type torch.uint8 and in the range [0 - 255]. + + Raises: + Exception: If unable to load images. 
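+
+    Example (illustrative; the frame paths below are assumptions):
+
+        >>> frames = _load_images_with_retries(
+        ...     ["frames/frame_0001.jpg", "frames/frame_0002.jpg"],
+        ...     num_retries=3,
+        ...     multithreaded=False,
+        ... )
+        >>> frames.shape  # (time, height, width, channel)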
""" imgs = [None for i in image_paths] diff --git a/pytorchvideo/data/hmdb51.py b/pytorchvideo/data/hmdb51.py index eb87eb3b..2ebc0375 100644 --- a/pytorchvideo/data/hmdb51.py +++ b/pytorchvideo/data/hmdb51.py @@ -178,16 +178,16 @@ def Hmdb51( decoder: str = "pyav", ) -> LabeledVideoDataset: """ - A helper function to create ``LabeledVideoDataset`` object for HMDB51 dataset + A helper function to create a `LabeledVideoDataset` object for HMDB51 dataset. Args: data_path (pathlib.Path): Path to the data. The path type defines how the data should be read: - * For a file path, the file is read and each line is parsed into a + - For a file path, the file is read and each line is parsed into a video path and label. - * For a directory, the directory structure defines the classes - (i.e. each subdirectory is a class). + - For a directory, the directory structure defines the classes + (i.e., each subdirectory is a class). clip_sampler (ClipSampler): Defines how clips should be sampled from each video. See the clip sampling documentation for more information. @@ -197,20 +197,25 @@ def Hmdb51( if necessary, the distributed split. transform (Callable): This callable is evaluated on the clip output before - the clip is returned. It can be used for user defined preprocessing and - augmentations to the clips. See the ``LabeledVideoDataset`` class for + the clip is returned. It can be used for user-defined preprocessing and + augmentations to the clips. See the `LabeledVideoDataset` class for clip output format. - video_path_prefix (str): Path to root directory with the videos that are - loaded in LabeledVideoDataset. All the video paths before loading + video_path_prefix (str): Path to the root directory with the videos that are + loaded in the `LabeledVideoDataset`. All the video paths before loading are prefixed with this path. - split_id (int): Fold id to be loaded. Options are 1, 2 or 3 + split_id (int): Fold id to be loaded. Options are 1, 2, or 3. - split_type (str): Split/Fold type to be loaded. Options are ("train", "test" or - "unused") + split_type (str): Split/Fold type to be loaded. Options are "train", "test", or + "unused". + + decode_audio (bool): Whether to decode audio or not. decoder (str): Defines which backend should be used to decode videos. + + Returns: + LabeledVideoDataset: A dataset object for the HMDB51 dataset. """ torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Hmdb51") diff --git a/pytorchvideo/data/json_dataset.py b/pytorchvideo/data/json_dataset.py index c86c1b51..5d764834 100644 --- a/pytorchvideo/data/json_dataset.py +++ b/pytorchvideo/data/json_dataset.py @@ -221,7 +221,7 @@ class UntrimmedClipSampler: """ A wrapper for adapting untrimmed annotated clips from the json_dataset to the standard `pytorchvideo.data.ClipSampler` expected format. Specifically, for each - clip it uses the provided `clip_sampler` to sample between "clip_start_sec" and + clip, it uses the provided `clip_sampler` to sample between "clip_start_sec" and "clip_end_sec" from the json_dataset clip annotation. """ @@ -236,6 +236,17 @@ def __init__(self, clip_sampler: ClipSampler) -> None: def __call__( self, last_clip_time: float, video_duration: float, clip_info: Dict[str, Any] ) -> ClipInfo: + """ + Sample a trimmed clip based on the provided untrimmed clip info. + + Args: + last_clip_time (float): The end time of the last sampled clip. + video_duration (float): The duration of the entire video. + clip_info (Dict[str, Any]): Information about the untrimmed clip. 
+ + Returns: + ClipInfo: Information about the sampled trimmed clip. + """ clip_start_boundary = clip_info["clip_start_sec"] clip_end_boundary = clip_info["clip_end_sec"] duration = clip_start_boundary - clip_end_boundary @@ -251,4 +262,8 @@ def __call__( ) def reset(self) -> None: + """ + Reset the state of the clip sampler. + """ pass + \ No newline at end of file diff --git a/pytorchvideo/data/kinetics.py b/pytorchvideo/data/kinetics.py index 6cdb3646..c7d0d6d8 100644 --- a/pytorchvideo/data/kinetics.py +++ b/pytorchvideo/data/kinetics.py @@ -43,7 +43,7 @@ def Kinetics( if necessary, the distributed split. transform (Callable): This callable is evaluated on the clip output before - the clip is returned. It can be used for user defined preprocessing and + the clip is returned. It can be used for user-defined preprocessing and augmentations to the clips. See the ``LabeledVideoDataset`` class for clip output format. @@ -53,7 +53,10 @@ def Kinetics( decode_audio (bool): If True, also decode audio from video. - decoder (str): Defines what type of decoder used to decode a video. + decoder (str): Defines what type of decoder is used to decode a video. + + Returns: + LabeledVideoDataset: The Kinetics dataset. """ diff --git a/pytorchvideo/data/utils.py b/pytorchvideo/data/utils.py index d90b6588..6846bb94 100644 --- a/pytorchvideo/data/utils.py +++ b/pytorchvideo/data/utils.py @@ -24,8 +24,13 @@ def thwc_to_cthw(data: torch.Tensor) -> torch.Tensor: """ - Permute tensor from (time, height, weight, channel) to - (channel, height, width, time). + Permute tensor from (time, height, weight, channel) to (channel, height, width, time). + + Args: + data (torch.Tensor): Input tensor of shape (time, height, weight, channel). + + Returns: + torch.Tensor: Permuted tensor of shape (channel, height, width, time). """ return data.permute(3, 0, 1, 2) @@ -37,11 +42,17 @@ def secs_to_pts( round_mode: str = "floor", ) -> int: """ - Converts a time (in seconds) to the given time base and start_pts offset - presentation time. Round_mode specifies the mode of rounding when converting time. + Converts a time (in seconds) to the given time base and start_pts offset presentation time. + Round_mode specifies the mode of rounding when converting time. + + Args: + time_in_seconds (float): Time in seconds. + time_base (float): Time base. + start_pts (int): Start presentation time. + round_mode (str): Rounding mode ("floor" or "ceil"). Returns: - pts (int): The time in the given time base. + int: Presentation time in the given time base. """ if time_in_seconds == math.inf: return math.inf @@ -56,10 +67,15 @@ def secs_to_pts( def pts_to_secs(pts: int, time_base: float, start_pts: int) -> float: """ - Converts a present time with the given time base and start_pts offset to seconds. + Converts a presentation time with the given time base and start_pts offset to seconds. + + Args: + pts (int): Presentation time. + time_base (float): Time base. + start_pts (int): Start presentation time. Returns: - time_in_seconds (float): The corresponding time in seconds. + float: Corresponding time in seconds. """ if pts == math.inf: return math.inf @@ -166,6 +182,10 @@ class MultiProcessSampler(torch.utils.data.Sampler): """ def __init__(self, sampler: torch.utils.data.Sampler) -> None: + """ + Args: + sampler (torch.utils.data.Sampler): The underlying PyTorch sampler to be split. 
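+
+        Example (illustrative):
+
+            >>> base = torch.utils.data.SequentialSampler(range(8))
+            >>> sampler = MultiProcessSampler(base)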
+ """ self._sampler = sampler def __iter__(self): @@ -184,7 +204,7 @@ def __iter__(self): if len(worker_split) == 0: logger.warning( f"More data workers({worker_info.num_workers}) than videos" - f"({len(self._sampler)}). For optimal use of processes " + f"({len(self._sampler)}). For optimal use of processes, " "reduce num_workers." ) return iter(()) @@ -194,26 +214,28 @@ def __iter__(self): worker_sampler = itertools.islice(iter(self._sampler), iter_start, iter_end) else: - # If no worker processes found, we return the full sampler. + # If no worker processes are found, we return the full sampler. worker_sampler = iter(self._sampler) return worker_sampler def optional_threaded_foreach( - target: Callable, args_iterable: Iterable[Tuple], multithreaded: bool + target: Callable, + args_iterable: Iterable[Tuple], + multithreaded: bool ): """ Applies 'target' function to each Tuple args in 'args_iterable'. - If 'multithreaded' a thread is spawned for each function application. + If 'multithreaded' is True, a thread is spawned for each function application. Args: target (Callable): A function that takes as input the parameters in each args_iterable Tuple. args_iterable (Iterable[Tuple]): - An iterable of the tuples each containing a set of parameters to pass to - target. + An iterable of tuples, each containing a set of parameters to pass to + the target function. multithreaded (bool): Whether or not the target applications are parallelized by thread. diff --git a/pytorchvideo/data/video.py b/pytorchvideo/data/video.py index 5d894eb8..0e717955 100644 --- a/pytorchvideo/data/video.py +++ b/pytorchvideo/data/video.py @@ -21,6 +21,22 @@ def __init__(self) -> None: def video_from_path( self, filepath, decode_video=True, decode_audio=False, decoder="pyav", fps=30 ): + """ + Returns a video object (either EncodedVideo or FrameVideo) based on the provided file path. + + Args: + filepath (str): Path to the video file or directory containing frame images. + decode_video (bool): Whether to decode the video (only for EncodedVideo). + decode_audio (bool): Whether to decode the audio (only for EncodedVideo). + decoder (str): The video decoder to use (only for EncodedVideo). + fps (int): Frames per second (only for FrameVideo). + + Returns: + Union[EncodedVideo, FrameVideo]: A video object based on the provided file path. + + Raises: + FileNotFoundError: If the file or directory specified by `filepath` does not exist. + """ try: is_file = g_pathmgr.isfile(filepath) is_dir = g_pathmgr.isdir(filepath) @@ -64,9 +80,13 @@ def __init__( decode_audio: bool = True, ) -> None: """ + Initializes the Video object with a file-like object containing the encoded video. + Args: - file (BinaryIO): a file-like object (e.g. io.BytesIO or io.StringIO) that + file (BinaryIO): A file-like object (e.g. io.BytesIO or io.StringIO) that contains the encoded video. + video_name (Optional[str]): An optional name for the video. + decode_audio (bool): Whether to decode audio from the video. """ pass @@ -74,8 +94,10 @@ def __init__( @abstractmethod def duration(self) -> float: """ + Returns the duration of the video in seconds. + Returns: - duration of the video in seconds + float: The duration of the video in seconds. """ pass @@ -88,14 +110,18 @@ def get_clip( in seconds (the video always starts at 0 seconds). Args: - start_sec (float): the clip start time in seconds - end_sec (float): the clip end time in seconds - Returns: - video_data_dictonary: A dictionary mapping strings to tensor of the clip's - underlying data. 
+ start_sec (float): The clip start time in seconds. + end_sec (float): The clip end time in seconds. + Returns: + Dict[str, Optional[torch.Tensor]]: A dictionary mapping strings to tensors + of the clip's underlying data. It may include video frames and audio. """ pass def close(self): + """ + Closes any resources associated with the Video object. + """ pass + \ No newline at end of file diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py b/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py index db74384c..f4c79613 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/activation_functions.py @@ -15,8 +15,16 @@ class _NaiveSwish(nn.Module): """ - Helper class to implement naive swish for deploy. It is not intended to be used to - build network. + Helper class to implement the naive Swish activation function for deployment. + It is not intended to be used to build networks. + + Swish(x) = x * sigmoid(x) + + Args: + None + + Returns: + torch.Tensor: The output tensor after applying the naive Swish activation. """ def __init__(self): @@ -24,32 +32,76 @@ def __init__(self): self.mul_func = nn.quantized.FloatFunctional() def forward(self, x): + """ + Forward pass through the naive Swish activation function. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The output tensor after applying the naive Swish activation. + """ return self.mul_func.mul(x, torch.sigmoid(x)) class Swish(EfficientBlockBase): """ - Swish activation function for efficient block. When in original form for training, - using custom op version of swish for better training memory efficiency. When in - deployable form, use naive swish as custom op is not supported to run on Pytorch - Mobile. For better latency on mobile CPU, use HardSwish instead. + Swish activation function for efficient block. + + When in its original form for training, it uses a custom op version of Swish for + better training memory efficiency. When in a deployable form, it uses a naive Swish + as the custom op is not supported for running on PyTorch Mobile. For better latency + on mobile CPU, consider using HardSwish instead. + + Args: + None + + Returns: + torch.Tensor: The output tensor after applying the Swish activation. """ def __init__(self): super().__init__() - self.act = SwishCustomOp() + + # Initialize the activation function based on whether it's for training or deployment + self.act = SwishCustomOp() # Use SwishCustomOp if defined, otherwise use _NaiveSwish def forward(self, x): + """ + Forward pass through the Swish activation function. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The output tensor after applying the Swish activation. + """ return self.act(x) - def convert(self, *args, **kwarg): + def convert(self, *args, **kwargs): + """ + Convert the activation function to use naive Swish for deployment. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + None + """ self.act = _NaiveSwish() class HardSwish(EfficientBlockBase): """ - Hardswish activation function. It is natively supported by Pytorch Mobile, and has + Hardswish activation function. It is natively supported by PyTorch Mobile and has better latency than Swish in int8 mode. + + Args: + None + + Returns: + torch.Tensor: The output tensor after applying the HardSwish activation. 
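A quick numerical sketch of how the Swish variants above relate to stock PyTorch ops: the deploy-time naive form is `x * sigmoid(x)` (identical to `nn.SiLU`), while `nn.Hardswish` is the cheaper approximation recommended for mobile CPU latency:

```
import torch
import torch.nn as nn

x = torch.randn(4, 8)

naive_swish = x * torch.sigmoid(x)          # the deployable _NaiveSwish computation
assert torch.allclose(naive_swish, nn.SiLU()(x))

hard_swish = nn.Hardswish()(x)              # piecewise-linear approximation of Swish
```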
""" def __init__(self): @@ -57,15 +109,41 @@ def __init__(self): self.act = nn.Hardswish() def forward(self, x): + """ + Forward pass through the HardSwish activation function. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The output tensor after applying the HardSwish activation. + """ return self.act(x) - def convert(self, *args, **kwarg): + def convert(self, *args, **kwargs): + """ + Placeholder method for converting the activation function. No conversion is + performed for HardSwish. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + None + """ pass class ReLU(EfficientBlockBase): """ ReLU activation function for EfficientBlockBase. + + Args: + None + + Returns: + torch.Tensor: The output tensor after applying the ReLU activation. """ def __init__(self): @@ -73,15 +151,42 @@ def __init__(self): self.act = nn.ReLU(inplace=True) def forward(self, x): + """ + Forward pass through the ReLU activation function. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The output tensor after applying the ReLU activation. + """ return self.act(x) - def convert(self, *args, **kwarg): + def convert(self, *args, **kwargs): + """ + Placeholder method for converting the activation function. No conversion is + performed for ReLU. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + None + """ pass class Identity(EfficientBlockBase): """ - Identity operation for EfficientBlockBase. + Identity operation for EfficientBlockBase. It simply returns the input tensor + unchanged. + + Args: + None + + Returns: + torch.Tensor: The input tensor itself. """ def __init__(self): @@ -89,12 +194,33 @@ def __init__(self): self.act = nn.Identity() def forward(self, x): + """ + Forward pass through the Identity operation. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: The input tensor itself. + """ return self.act(x) - def convert(self, *args, **kwarg): + def convert(self, *args, **kwargs): + """ + Placeholder method for converting the identity operation. No conversion is + performed for Identity. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + None + """ pass + supported_act_functions = { "relu": ReLU, "swish": Swish, diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/attention.py b/pytorchvideo/layers/accelerator/mobile_cpu/attention.py index 3a6309e4..4f5ec58f 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/attention.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/attention.py @@ -15,7 +15,9 @@ class SqueezeExcitation(EfficientBlockBase): """ - Efficient Squeeze-Excitation (SE). The Squeeze-Excitation block is described in: + Efficient Squeeze-Excitation (SE) block. + + The Squeeze-Excitation block is described in: *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507* This implementation has the same instantiation interface as SE implementation in fvcore, and in original mode for training it is just a wrapped version of SE in @@ -26,6 +28,16 @@ class SqueezeExcitation(EfficientBlockBase): convert_flag variable is to record whether the SqueezeExcitation instance has been converted; SqueezeExcitation is in original form if convert_flag is false, while it is in deployable form if convert_flag is true. + + + Args: + num_channels (int): Number of input channels. + num_channels_reduced (int): + Number of reduced channels. If none, uses reduction_ratio to calculate. 
+ reduction_ratio (float): + How much num_channels should be reduced if num_channels_reduced is not provided. + is_3d (bool): Whether we're operating on 3d data (or 2d), default 2d. + activation (nn.Module): Activation function used, defaults to ReLU. """ def __init__( @@ -36,18 +48,7 @@ def __init__( is_3d: bool = False, activation: Optional[nn.Module] = None, ) -> None: - """ - Args: - num_channels (int): Number of input channels. - num_channels_reduced (int): - Number of reduced channels. If none, uses reduction_ratio to calculate. - reduction_ratio (float): - How much num_channels should be reduced if num_channels_reduced is not provided. - is_3d (bool): Whether we're operating on 3d data (or 2d), default 2d. - activation (nn.Module): Activation function used, defaults to ReLU. - """ super().__init__() - # Implement SE from FVCore here for training. self.se = SqueezeExcitationFVCore( num_channels, num_channels_reduced=num_channels_reduced, @@ -62,14 +63,15 @@ def convert(self, input_blob_size, **kwargs): """ Converts into efficient version of squeeze-excite (SE) for CPU. It changes conv in original SE into linear layer (better supported by CPU). + + Args: + input_blob_size (tuple): Size of the input blob. """ if self.is_3d: avg_pool = nn.AdaptiveAvgPool3d(1) else: avg_pool = nn.AdaptiveAvgPool2d(1) - """ - Reshape tensor size to (B, C) for linear layer. - """ + reshape0 = _Reshape((input_blob_size[0], input_blob_size[1])) fc0 = nn.Linear( self.se.block[0].in_channels, @@ -89,10 +91,7 @@ def convert(self, input_blob_size, **kwargs): state_dict_fc1["weight"] = state_dict_fc1["weight"].squeeze() fc1.load_state_dict(state_dict_fc1) sigmoid = deepcopy(self.se.block[3]) - """ - Output of linear layer has output shape of (B, C). Need to reshape to proper - shape before multiplying with input tensor. - """ + reshape_size_after_sigmoid = (input_blob_size[0], input_blob_size[1], 1, 1) + ( (1,) if self.is_3d else () ) @@ -100,10 +99,19 @@ def convert(self, input_blob_size, **kwargs): se_layers = nn.Sequential( avg_pool, reshape0, fc0, activation, fc1, sigmoid, reshape1 ) - # Add final elementwise multiplication and replace self.se + self.se = _SkipConnectMul(se_layers) self.convert_flag = True def forward(self, x) -> torch.Tensor: + """ + Forward pass through the Squeeze-Excitation (SE) block. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying SE block. + """ out = self.se(x) return out diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py b/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py index 9d9d7c22..3848d390 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/conv_helper.py @@ -16,8 +16,9 @@ class _Reshape(nn.Module): """ Helper class to implement data reshape as a module. + Args: - reshape_size (tuple): size of data after reshape. + reshape_size (tuple): Size of data after reshape. """ def __init__( @@ -28,15 +29,25 @@ def __init__( self.reshape_size = reshape_size def forward(self, x): + """ + Forward pass through the reshape operation. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Reshaped tensor. + """ return torch.reshape(x, self.reshape_size) class _SkipConnectMul(nn.Module): """ Helper class to implement skip multiplication. + Args: - layer (nn.Module): layer for skip multiplication. With input x, _SkipConnectMul - implements layer(x)*x. + layer (nn.Module): Layer for skip multiplication. 
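The conv-to-linear swap performed by `SqueezeExcitation.convert()` above rests on a simple identity: a 1x1 convolution applied to a globally pooled tensor (spatial size 1x1) is just a fully connected layer with the squeezed weights. A minimal numerical sketch:

```
import torch
import torch.nn as nn

conv = nn.Conv2d(16, 4, kernel_size=1)    # 1x1 conv, as inside the original SE block
fc = nn.Linear(16, 4)
fc.load_state_dict(
    {"weight": conv.weight.detach().squeeze(), "bias": conv.bias.detach()}
)

x = torch.randn(2, 16, 1, 1)              # globally pooled feature map
assert torch.allclose(conv(x).flatten(1), fc(x.flatten(1)), atol=1e-6)
```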
With input x, _SkipConnectMul + implements layer(x) * x. """ def __init__( @@ -48,9 +59,23 @@ def __init__( self.mul_func = nn.quantized.FloatFunctional() def forward(self, x): + """ + Forward pass through the skip multiplication operation. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Result of the skip multiplication. + """ return self.mul_func.mul(x, self.layer(x)) + +import torch.nn as nn +from typing import Tuple +from copy import deepcopy + class _Conv3dTemporalKernel3Decomposed(nn.Module): """ Helper class for decomposing conv3d with temporal kernel of 3 into equivalent conv2ds. @@ -78,11 +103,11 @@ def __init__( """ super().__init__() assert conv3d_in.padding[0] == 1, ( - "_Conv3dTemporalKernel3Eq only support temporal padding of 1, " + "_Conv3dTemporalKernel3Eq only supports temporal padding of 1, " f"but got {conv3d_in.padding[0]}" ) assert conv3d_in.padding_mode == "zeros", ( - "_Conv3dTemporalKernel3Eq only support zero padding, " + "_Conv3dTemporalKernel3Eq only supports zero padding, " f"but got {conv3d_in.padding_mode}" ) self._input_THW_tuple = input_THW_tuple @@ -92,10 +117,11 @@ def __init__( kernel_size = conv3d_in.kernel_size[1:] groups = conv3d_in.groups stride_2d = conv3d_in.stride[1:] - # Create 3 conv2d to emulate conv3d. + + # Create 3 conv2d layers to emulate conv3d. if ( self._input_THW_tuple[0] > 1 - ): # Those two conv2d are needed only when temporal input > 1. + ): # These two conv2d layers are needed only when temporal input > 1. self._conv2d_3_3_0 = nn.Conv2d( in_channels, out_channels, @@ -155,13 +181,13 @@ def __init__( def forward(self, x): """ - Use three conv2d to emulate conv3d. - This forward assumes zero padding of size 1 in temporal dimension. + Use three conv2d layers to emulate conv3d. + This forward assumes zero padding of size 1 in the temporal dimension. """ if self._input_THW_tuple[0] > 1: out_tensor_list = [] """ - First output plane in temporal dimension, + First output plane in the temporal dimension, conv2d_3_3_0 is skipped due to zero padding. """ cur_tensor = ( @@ -184,7 +210,7 @@ def forward(self, x): ) out_tensor_list.append(cur_tensor) """ - Last output plane in temporal domain, conv2d_3_3_2 is skipped due to zero padding. + Last output plane in the temporal domain, conv2d_3_3_2 is skipped due to zero padding. """ cur_tensor = ( self._add_funcs[-1] @@ -193,7 +219,7 @@ def forward(self, x): ) out_tensor_list.append(cur_tensor) return self._cat_func.cat(out_tensor_list, 2) - else: # Degenerated to simple conv2d + else: # Degenerated to a simple conv2d return self._conv2d_3_3_1(x[:, :, 0]).unsqueeze(2) diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py b/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py index e1e29b07..949ed2d9 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/convolutions.py @@ -110,12 +110,13 @@ def convert( For quantized operation on new version of QNNPACK with native int8 Conv3d, this function will only apply operator fusion. + Args: - input_blob_size (tuple): blob size at the input of Conv3dPwBnAct instance. - convert_for_quantize (bool): whether this module is intended to be quantized. - native_conv3d_op_qnnpack (bool): whether the QNNPACK version has native int8 + input_blob_size (tuple): Blob size at the input of Conv3dPwBnAct instance. + convert_for_quantize (bool): Whether this module is intended to be quantized. 
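The decomposition implemented by `_Conv3dTemporalKernel3Decomposed` above can be checked numerically: with temporal kernel 3 and temporal zero padding 1, each output frame is the sum of three 2D convolutions applied to the previous, current, and next input frames. A small sketch of that equivalence (not the library class itself):

```
import torch
import torch.nn as nn

torch.manual_seed(0)
conv3d = nn.Conv3d(2, 3, kernel_size=(3, 3, 3), padding=(1, 1, 1), bias=False)

# One 2D conv per temporal slice of the 3D kernel.
convs2d = []
for t in range(3):
    c = nn.Conv2d(2, 3, kernel_size=3, padding=1, bias=False)
    c.weight.data.copy_(conv3d.weight[:, :, t])
    convs2d.append(c)

x = torch.randn(1, 2, 4, 8, 8)            # (B, C, T, H, W)
ref = conv3d(x)

outs = []
for t in range(x.shape[2]):
    # out[t] = W0 * x[t-1] + W1 * x[t] + W2 * x[t+1]; missing neighbours are zero.
    y = convs2d[1](x[:, :, t])
    if t - 1 >= 0:
        y = y + convs2d[0](x[:, :, t - 1])
    if t + 1 < x.shape[2]:
        y = y + convs2d[2](x[:, :, t + 1])
    outs.append(y.unsqueeze(2))

assert torch.allclose(ref, torch.cat(outs, dim=2), atol=1e-5)
```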
+ native_conv3d_op_qnnpack (bool): Whether the QNNPACK version has native int8 Conv3d. - kwargs (any): any extra keyword arguments from upstream unused by convert(). + kwargs (any): Any extra keyword arguments from upstream unused by convert(). """ assert ( self.convert_flag is False @@ -175,6 +176,15 @@ def convert( self.convert_flag = True def forward(self, x): + """ + Forward pass of the Conv3dPwBnAct module. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ x = self.kernel(x) return x @@ -390,6 +400,7 @@ def convert( This conversion is done by first fuse conv3d with bn, convert conv3d into equivalent conv2d, and optionally fuse conv2d with relu. + Args: input_blob_size (tuple): blob size at the input of Conv3dTemporalKernel1BnAct instance during forward. diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py b/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py index 83421d4f..59d04f59 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/fully_connected.py @@ -10,6 +10,7 @@ class FullyConnected(NoOpConvertBlock): """ Implements fully connected layer. This operator is natively supported by QNNPACK for mobile CPU with good efficiency, and no change is made upon convert(). + Args: in_features (int): input channels for FC layer. out_features (int): output channels for FC layer. diff --git a/pytorchvideo/layers/accelerator/mobile_cpu/pool.py b/pytorchvideo/layers/accelerator/mobile_cpu/pool.py index 1e92ee9e..47e95fb2 100644 --- a/pytorchvideo/layers/accelerator/mobile_cpu/pool.py +++ b/pytorchvideo/layers/accelerator/mobile_cpu/pool.py @@ -18,18 +18,22 @@ class AdaptiveAvgPool3dOutSize1(EfficientBlockBase): """ def __init__(self): + """ + Initializes an AdaptiveAvgPool3dOutSize1 layer. + """ super().__init__() self.pool = nn.AdaptiveAvgPool3d(1) self.convert_flag = False def convert(self, input_blob_size: Tuple, **kwargs): """ - Converts AdaptiveAvgPool into AvgPool with constant kernel size for better + Converts AdaptiveAvgPool into AvgPool with a constant kernel size for better efficiency. + Args: - input_blob_size (tuple): blob size at the input of - AdaptiveAvgPool3dOutSize1 instance during forward. - kwargs (any): any keyword argument (unused). + input_blob_size (tuple): Blob size at the input of the AdaptiveAvgPool3dOutSize1 + instance during forward. + kwargs (any): Any keyword arguments (unused). """ assert ( self.convert_flag is False @@ -39,6 +43,15 @@ def convert(self, input_blob_size: Tuple, **kwargs): self.convert_flag = True def forward(self, x): + """ + Forward pass through the layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the AdaptiveAvgPool3dOutSize1 operation. + """ return self.pool(x) @@ -46,11 +59,18 @@ class AdaptiveAvgPool2dOutSize1(EfficientBlockBase): """ Implements AdaptiveAvgPool2d with output (H, W) = (1, 1). This operator has better efficiency than AdaptiveAvgPool for mobile CPU. + + Attributes: + pool (nn.AdaptiveAvgPool2d): The AdaptiveAvgPool2d layer with output size (1, 1). + convert_flag (bool): Flag indicating whether the conversion has been applied. """ def __init__( self, ): + """ + Initializes an AdaptiveAvgPool2dOutSize1 instance. 
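The pooling swap described for `AdaptiveAvgPool3dOutSize1.convert()` above relies on the input size being fixed at conversion time, so the adaptive pool can be replaced by an `AvgPool3d` whose kernel spans the whole (T, H, W) extent. A small sketch with a hypothetical input blob size:

```
import torch
import torch.nn as nn

input_blob_size = (1, 16, 4, 7, 7)        # hypothetical (B, C, T, H, W)
x = torch.randn(input_blob_size)

adaptive = nn.AdaptiveAvgPool3d(1)
fixed = nn.AvgPool3d(kernel_size=input_blob_size[2:])

assert torch.allclose(adaptive(x), fixed(x), atol=1e-6)
```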
+ """ super().__init__() self.pool = nn.AdaptiveAvgPool2d(1) self.convert_flag = False @@ -59,10 +79,14 @@ def convert(self, input_blob_size: Tuple, **kwargs): """ Converts AdaptiveAvgPool into AvgPool with constant kernel size for better efficiency. + Args: - input_blob_size (tuple): blob size at the input of - AdaptiveAvgPool2dOutSize1 instance during forward. - kwargs (any): any keyword argument (unused). + input_blob_size (tuple): Blob size at the input of AdaptiveAvgPool2dOutSize1 + instance during forward. + kwargs (any): Any keyword argument (unused). + + Raises: + AssertionError: If conversion is attempted after it has already been applied. """ assert ( self.convert_flag is False @@ -72,17 +96,27 @@ def convert(self, input_blob_size: Tuple, **kwargs): self.convert_flag = True def forward(self, x): + """ + Forward pass through the layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after pooling. + """ return self.pool(x) class AdaptiveAvgPool3d(NoOpConvertBlock): """ Implements AdaptiveAvgPool3d with any output (T, H, W) size. This operator is - supported by QNNPACK for mobile CPU with resonable efficiency, and no change is + supported by QNNPACK for mobile CPU with reasonable efficiency, and no change is made upon convert(). If the output (T, H, W) = (1, 1, 1), use AdaptiveAvgPool3dOutSize1 for better efficiency. + Args: - output_size (int or tuple): when it is a tuple, the output (T, H, W) of pool + output_size (int or tuple): When it is a tuple, the output (T, H, W) of the pool will be equal to output_size. When it is an int, the output (T, H, W) will be equal to (output_size, output_size, output_size). """ @@ -91,17 +125,24 @@ def __init__( self, output_size: Union[int, Tuple], ): + """ + Initializes an AdaptiveAvgPool3d instance. + + Args: + output_size (int or tuple): Desired output size for the pooling operation. + """ super().__init__(model=nn.AdaptiveAvgPool3d(output_size)) class AdaptiveAvgPool2d(NoOpConvertBlock): """ Implements AdaptiveAvgPool2d with any output (H, W) size. This operator is - supported by QNNPACK for mobile CPU with resonable efficiency, and no change is + supported by QNNPACK for mobile CPU with reasonable efficiency, and no change is made upon convert(). If the output (H, W) = (1, 1), use AdaptiveAvgPool2dOutSize1 for better efficiency. + Args: - output_size (int or tuple): when it is a tuple, the output (H, W) of pool + output_size (int or tuple): When it is a tuple, the output (H, W) of the pool will be equal to output_size. When it is an int, the output (H, W) will be equal to (output_size, output_size). """ @@ -110,4 +151,10 @@ def __init__( self, output_size: Union[int, Tuple], ): + """ + Initializes an AdaptiveAvgPool2d instance. + + Args: + output_size (int or tuple): Desired output size for the pooling operation. + """ super().__init__(model=nn.AdaptiveAvgPool2d(output_size)) diff --git a/pytorchvideo/layers/attention.py b/pytorchvideo/layers/attention.py index 1b9f65e2..bafd67a3 100644 --- a/pytorchvideo/layers/attention.py +++ b/pytorchvideo/layers/attention.py @@ -17,6 +17,24 @@ @torch.fx.wrap def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + """ + Unsqueezes dimensions of a 3D tensor to make it 4D if needed. + + Args: + tensor (torch.Tensor): The input tensor. + + Returns: + Tuple[torch.Tensor, int]: A tuple containing the modified tensor and its dimension. + + If the input tensor has 3 dimensions, it adds a new dimension at the second position to make + it 4D. 
If the tensor already has 4 dimensions, it does nothing. + + Example: + ``` + input_tensor = torch.randn(32, 3, 64, 64) + modified_tensor, new_dim = _unsqueeze_dims_fx(input_tensor) + ``` + """ tensor_dim = tensor.ndim if tensor_dim == 4: pass @@ -29,11 +47,39 @@ def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: @torch.jit.script def _unsqueeze_dims_jit(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + """ + JIT script version of _unsqueeze_dims_fx. + + Args: + tensor (torch.Tensor): The input tensor. + + Returns: + Tuple[torch.Tensor, int]: A tuple containing the modified tensor and its dimension. + """ return _unsqueeze_dims_fx(tensor) @torch.fx.wrap def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: + """ + Squeezes dimensions of a 4D tensor to make it 3D if needed. + + Args: + tensor (torch.Tensor): The input tensor. + tensor_dim (int): The original dimension of the tensor. + + Returns: + torch.Tensor: The modified tensor. + + If the input tensor has 4 dimensions and `tensor_dim` is 3, it removes the second dimension + to make it 3D. If the tensor already has 3 dimensions and `tensor_dim` is 3, it does nothing. + + Example: + ``` + input_tensor = torch.randn(32, 1, 64, 64) + modified_tensor = _squeeze_dims_fx(input_tensor, 3) + ``` + """ if tensor_dim == 4: pass elif tensor_dim == 3: @@ -45,6 +91,16 @@ def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: @torch.jit.script def _squeeze_dims_jit(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: + """ + JIT script version of _squeeze_dims_fx. + + Args: + tensor (torch.Tensor): The input tensor. + tensor_dim (int): The original dimension of the tensor. + + Returns: + torch.Tensor: The modified tensor. + """ return _squeeze_dims_fx(tensor, tensor_dim) @@ -64,6 +120,20 @@ class Mlp(nn.Module): Linear (hidden_features, out_features) ↓ Dropout (p=dropout_rate) + + Args: + in_features (int): Input feature dimension. + hidden_features (Optional[int]): Hidden feature dimension (default is input dimension). + out_features (Optional[int]): Output feature dimension (default is input dimension). + act_layer (Callable): Activation layer applied after the first linear layer. + dropout_rate (float): Dropout rate after each linear layer (0.0 by default). + bias_on (bool): Whether to use biases for linear layers (True by default). + + Example: + ``` + mlp_block = Mlp(in_features=256, hidden_features=512, dropout_rate=0.1) + output = mlp_block(input_tensor) + ``` """ def __init__( @@ -135,13 +205,25 @@ def __init__( Norm - Params: - pool (Optional[Callable]): Pool operation that is applied to the input tensor. - If pool is none, return the input tensor. - has_cls_embed (bool): Whether the input tensor contains cls token. Pool - operation excludes cls token. - norm: (Optional[Callable]): Optional normalization operation applied to - tensor after pool. + Args: + pool (Optional[torch.nn.Module]): Pooling operation applied to the input tensor. + If None, no pooling is applied. + has_cls_embed (bool): Indicates whether the input tensor contains a cls token. + The pooling operation excludes the cls token if present. + norm (Optional[torch.nn.Module]): Optional normalization operation applied to + the tensor after pooling. + + This class applies a specified pooling operation to a flattened input tensor, preserving the + spatial structure. If the input tensor contains a cls token, the pooling operation excludes it. + An optional normalization operation can be applied after pooling. 
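The reshape-pool-flatten pattern that `_AttentionPool` applies to a flattened token sequence can be sketched directly; cls-token and multi-head handling are omitted, and the pooling module and tensor sizes are illustrative:

```
import torch
import torch.nn as nn

B, C = 2, 96
T, H, W = 4, 14, 14
tokens = torch.randn(B, T * H * W, C)          # flattened spatiotemporal tokens

pool = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

# Restore the (B, C, T, H, W) layout, pool, then flatten back to a token sequence.
x = tokens.transpose(1, 2).reshape(B, C, T, H, W)
x = pool(x)
new_thw = list(x.shape[2:])                    # [4, 7, 7]
pooled_tokens = x.flatten(2).transpose(1, 2)   # (B, T'*H'*W', C)
print(pooled_tokens.shape, new_thw)            # torch.Size([2, 196, 96]) [4, 7, 7]
```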
+ + Example: + ``` + pool_layer = _AttentionPool(pool=torch.nn.MaxPool3d(kernel_size=(2, 2, 2)), + has_cls_embed=True, + norm=torch.nn.LayerNorm((16, 16, 16))) + output_tensor, output_shape = pool_layer(input_tensor, [32, 16, 16]) + ``` """ super().__init__() self.has_pool = pool is not None @@ -163,13 +245,19 @@ def forward( self, tensor: torch.Tensor, thw_shape: List[int] ) -> Tuple[torch.Tensor, List[int]]: """ + Applies the specified pooling operation to the input tensor while preserving spatial structure. + Args: tensor (torch.Tensor): Input tensor. - thw_shape (List): The shape of the input tensor (before flattening). + thw_shape (List[int]): The shape of the input tensor (before flattening). Returns: - tensor (torch.Tensor): Input tensor after pool. - thw_shape (List[int]): Output tensor shape (before flattening). + torch.Tensor: Output tensor after pooling. + List[int]: Output tensor shape (before flattening). + + This method reshapes the input tensor, applies the pooling operation, and restores the + original shape. If normalization is used, it can be applied before or after pooling + based on the configuration. """ if not self.has_pool: return tensor, thw_shape @@ -433,6 +521,25 @@ def _qkv_proj( batch_size: int, chan_size: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Project the query (q), key (k), and value (v) tensors. + + Args: + q (torch.Tensor): Query tensor. + q_size (int): Query size. + k (torch.Tensor): Key tensor. + k_size (int): Key size. + v (torch.Tensor): Value tensor. + v_size (int): Value size. + batch_size (int): Batch size. + chan_size (int): Channel size. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected query, key, and value tensors. + + This method applies linear projections to the query, key, and value tensors and reshapes them + as needed for subsequent attention computations. + """ q = ( self.q(q) .reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads) @@ -459,6 +566,23 @@ def _qkv_pool( ) -> Tuple[ torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int] ]: + """ + Apply pooling to query (q), key (k), and value (v) tensors. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + thw_shape (List[int]): The shape of the input tensor (before flattening). + + Returns: + Tuple[torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]]: + Processed query, key, and value tensors along with their respective shapes. + + This method applies attention pooling to the query, key, and value tensors and returns + the processed tensors along with their shapes. + """ + q, q_shape = self._attention_pool_q(q, thw_shape) k, k_shape = self._attention_pool_k(k, thw_shape) v, v_shape = self._attention_pool_v(v, thw_shape) @@ -470,13 +594,38 @@ def _get_qkv_length( k_shape: List[int], v_shape: List[int], ) -> Tuple[int, int, int]: + """ + Calculate the lengths of query (q), key (k), and value (v) tensors. + + Args: + q_shape (List[int]): Shape of the query tensor. + k_shape (List[int]): Shape of the key tensor. + v_shape (List[int]): Shape of the value tensor. + + Returns: + Tuple[int, int, int]: Lengths of query, key, and value tensors. + + This method calculates the lengths of query, key, and value tensors, taking into account + whether the input tensor contains a cls token. 
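The projection step in `_qkv_proj` above boils down to a linear layer followed by splitting the channel dimension into heads. A minimal shape sketch (dimensions are illustrative):

```
import torch
import torch.nn as nn

B, N, C, num_heads = 2, 197, 96, 8
x = torch.randn(B, N, C)
q_proj = nn.Linear(C, C)

# (B, N, C) -> (B, N, heads, C/heads) -> (B, heads, N, C/heads)
q = q_proj(x).reshape(B, N, num_heads, C // num_heads).permute(0, 2, 1, 3)
print(q.shape)   # torch.Size([2, 8, 197, 12])
```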
+ """ q_N = self._prod(q_shape) + 1 if self.has_cls_embed else self._prod(q_shape) k_N = self._prod(k_shape) + 1 if self.has_cls_embed else self._prod(k_shape) v_N = self._prod(v_shape) + 1 if self.has_cls_embed else self._prod(v_shape) return q_N, k_N, v_N def _prod(self, shape: List[int]) -> int: - """Torchscriptable version of `numpy.prod`. Note that `_prod([]) == 1`""" + """ + Torchscriptable version of `numpy.prod`. Note that `_prod([]) == 1` + + Args: + shape (List[int]): List of dimensions. + + Returns: + int: Product of the dimensions in the shape list. + + This method calculates the product of dimensions in the input list, equivalent to + `numpy.prod`. + """ p: int = 1 for dim in shape: p *= dim @@ -493,6 +642,25 @@ def _reshape_qkv_to_seq( B: int, C: int, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Reshape and transpose the query (q), key (k), and value (v) tensors. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + q_N (int): Length of query tensor. + v_N (int): Length of value tensor. + k_N (int): Length of key tensor. + B (int): Batch size. + C (int): Channel size. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Reshaped and transposed query, key, and value tensors. + + This method reshapes and transposes the query, key, and value tensors for further computation + in the attention mechanism. + """ q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) v = v.permute(0, 2, 1, 3).reshape(B, v_N, C) k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) @@ -502,9 +670,17 @@ def forward( self, x: torch.Tensor, thw_shape: List[int] ) -> Tuple[torch.Tensor, List[int]]: """ + Forward pass through the MultiScaleAttention block. + Args: x (torch.Tensor): Input tensor. thw_shape (List): The shape of the input tensor (before flattening). + + Returns: + Tuple[torch.Tensor, List[int]]: Output tensor and updated shape. + + This method computes the forward pass through the MultiScaleAttention block, including + projections, pooling, attention, and transformations. """ B, N, C = x.shape @@ -553,6 +729,21 @@ def _load_from_state_dict( unexpected_keys, error_msgs, ): + """ + Load parameters from a state dictionary with support for backward compatibility. + + Args: + state_dict: State dictionary. + prefix: Prefix for the keys in the state dictionary. + local_metadata: Local metadata. + strict: Whether to enforce strict loading. + missing_keys: List to store missing keys. + unexpected_keys: List to store unexpected keys. + error_msgs: List to store error messages. + + This method is used to load parameters from a state dictionary with support for backward + compatibility by renaming keys as needed. + """ version = local_metadata.get("version", None) if version is None or version < 2: @@ -577,8 +768,7 @@ def _load_from_state_dict( class MultiScaleBlock(nn.Module): """ - Implementation of a multiscale vision transformer block. Each block contains a - multiscale attention layer and a Mlp layer. + Multi-Scale Vision Transformer block with Multi-Scale Attention and MLP layers. :: @@ -605,6 +795,42 @@ class MultiScaleBlock(nn.Module): DropPath | ↓ | Summation ←------------+ + + Args: + dim (int): Input feature dimension. + dim_out (int): Output feature dimension. + num_heads (int): Number of heads in the attention layer. + mlp_ratio (float): MLP ratio controlling the hidden layer dimension. + qkv_bias (bool): Whether to use bias in the QKV projection. + dropout_rate (float): Dropout rate (0.0 by default, disabled). 
+ droppath_rate (float): DropPath rate (0.0 by default, disabled). + act_layer (nn.Module): Activation layer used in the MLP block. + norm_layer (nn.Module): Normalization layer. + attn_norm_layer (nn.Module): Normalization layer in the attention module. + dim_mul_in_att (bool): If True, dimension expansion occurs inside the attention module, + otherwise it occurs in the MLP block. + kernel_q (_size_3_t): Pooling kernel size for q (1, 1, 1 by default). + kernel_kv (_size_3_t): Pooling kernel size for kv (1, 1, 1 by default). + stride_q (_size_3_t): Pooling kernel stride for q (1, 1, 1 by default). + stride_kv (_size_3_t): Pooling kernel stride for kv (1, 1, 1 by default). + pool_mode (str): Pooling mode ("conv" by default, can be "avg", or "max"). + has_cls_embed (bool): Whether the input tensor contains a cls token. + pool_first (bool): If True, apply pooling before qkv projection. + residual_pool (bool): If True, use pooling with Improved Multiscale Vision Transformer's + pooling residual connection. + depthwise_conv (bool): Whether to use depthwise or full convolution for pooling. + bias_on (bool): Whether to use biases for linear layers. + separate_qkv (bool): Whether to use separate layers for qkv projections. + + This class represents a Multi-Scale Vision Transformer block, which consists of a Multi-Scale + Attention layer and an MLP layer. The block can perform dimension expansion either inside the + attention module or the MLP block based on the `dim_mul_in_att` parameter. + + Example: + ``` + multi_scale_block = MultiScaleBlock(dim=256, dim_out=512, num_heads=8) + output_tensor, output_shape = multi_scale_block(input_tensor, [32, 16, 16]) + ``` """ def __init__( @@ -730,9 +956,19 @@ def forward( self, x: torch.Tensor, thw_shape: List[int] ) -> Tuple[torch.Tensor, List[int]]: """ + Forward pass of the MultiScaleBlock. + Args: x (torch.Tensor): Input tensor. thw_shape (List): The shape of the input tensor (before flattening). + + Returns: + torch.Tensor: Output tensor. + List[int]: Output tensor shape (before flattening). + + This method processes the input tensor through the Multi-Scale Attention and MLP layers, + handling dimension expansion based on the configuration. It can also apply pooling if + specified. """ x_norm = ( diff --git a/pytorchvideo/layers/attention_torchscript.py b/pytorchvideo/layers/attention_torchscript.py index 8bfb080f..4529f936 100644 --- a/pytorchvideo/layers/attention_torchscript.py +++ b/pytorchvideo/layers/attention_torchscript.py @@ -13,8 +13,8 @@ class Mlp(nn.Module): """ - A MLP block that contains two linear layers with a normalization layer. The MLP - block is used in a transformer model after the attention block. + A Multilayer Perceptron (MLP) block that contains two linear layers with a normalization layer. + The MLP block is commonly used in transformer models after the attention block. :: @@ -27,6 +27,16 @@ class Mlp(nn.Module): Linear (hidden_features, out_features) ↓ Dropout (p=dropout_rate) + + Args: + in_features (int): Input feature dimension. + hidden_features (Optional[int]): Hidden feature dimension. Defaults to None, which sets it + equal to the input feature dimension. + out_features (Optional[int]): Output feature dimension. Defaults to None, which sets it + equal to the input feature dimension. + act_layer (Callable): Activation layer used after the first linear layer. Defaults to nn.GELU. + dropout_rate (float): Dropout rate after each linear layer. Defaults to 0.0 (no dropout). 
+ bias_on (bool): Whether to include bias terms in linear layers. Defaults to True. """ def __init__( @@ -39,15 +49,7 @@ def __init__( bias_on: bool = True, ) -> None: """ - Args: - in_features (int): Input feature dimension. - hidden_features (Optional[int]): Hidden feature dimension. By default, - hidden feature is set to input feature dimension. - out_features (Optional[int]): Output feature dimension. By default, output - features dimension is set to input feature dimension. - act_layer (Callable): Activation layer used after the first linear layer. - dropout_rate (float): Dropout rate after each linear layer. Dropout is not used - by default. + Initializes an instance of the Mlp class. """ super().__init__() self.dropout_rate = dropout_rate @@ -60,8 +62,13 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ + Performs a forward pass through the MLP block. + Args: x (tensor): Input tensor. + + Returns: + tensor: Output tensor. """ x = self.fc1(x) x = self.act(x) @@ -71,6 +78,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @torch.fx.wrap def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + """ + Unsqueezes dimensions of a tensor if needed to ensure it has a dimension of 4. + + Args: + tensor (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, int]: A tuple containing the modified tensor and its original dimension. + + Raises: + NotImplementedError: If the input tensor dimension is not supported. + """ tensor_dim = tensor.ndim if tensor_dim == 4: pass @@ -83,11 +102,37 @@ def _unsqueeze_dims_fx(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: @torch.jit.script def _unsqueeze_dims_jit(tensor: torch.Tensor) -> Tuple[torch.Tensor, int]: + """ + Unsqueezes dimensions of a tensor if needed to ensure it has a dimension of 4. + (JIT script version of _unsqueeze_dims_fx) + + Args: + tensor (torch.Tensor): Input tensor. + + Returns: + Tuple[torch.Tensor, int]: A tuple containing the modified tensor and its original dimension. + + Raises: + NotImplementedError: If the input tensor dimension is not supported. + """ return _unsqueeze_dims_fx(tensor) @torch.fx.wrap def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: + """ + Squeezes dimensions of a tensor if needed to ensure it has a dimension of 4. + + Args: + tensor (torch.Tensor): Input tensor. + tensor_dim (int): Original dimension of the tensor. + + Returns: + torch.Tensor: The modified tensor. + + Raises: + NotImplementedError: If the input tensor dimension is not supported. + """ if tensor_dim == 4: pass elif tensor_dim == 3: @@ -99,6 +144,20 @@ def _squeeze_dims_fx(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: @torch.jit.script def _squeeze_dims_jit(tensor: torch.Tensor, tensor_dim: int) -> torch.Tensor: + """ + Squeezes dimensions of a tensor if needed to ensure it has a dimension of 4. + (JIT script version of _squeeze_dims_fx) + + Args: + tensor (torch.Tensor): Input tensor. + tensor_dim (int): Original dimension of the tensor. + + Returns: + torch.Tensor: The modified tensor. + + Raises: + NotImplementedError: If the input tensor dimension is not supported. + """ return _squeeze_dims_fx(tensor, tensor_dim) @@ -108,6 +167,8 @@ def _pre_attention_pool( ) -> Tuple[torch.Tensor, Tuple[int, int, int, int, int, int, int, int]]: """ Apply pool to a flattened input (given pool operation and the unflattened shape). + Pre-processes the input tensor for attention pooling by reshaping and optionally + applying normalization. 
Input @@ -150,6 +211,17 @@ def _post_attention_pool( tensor: torch.Tensor, thw_shape: List[int], ) -> Tuple[torch.Tensor, List[int]]: + """ + Post-processes the input tensor after attention pooling. + + Args: + tensor (torch.Tensor): Input tensor. + thw_shape (List[int]): The original shape of the input tensor (before flattening). + + Returns: + torch.Tensor: Post-processed input tensor. + List[int]: Output tensor shape (after flattening). + """ B, N, L, C, T, H, W, tensor_dim = thw_shape thw_shape = [tensor.shape[2], tensor.shape[3], tensor.shape[4]] @@ -342,6 +414,22 @@ def _qkv_proj( batch_size: List[int], chan_size: List[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Applies projection to q, k, and v for multi-head attention. + + Args: + q (torch.Tensor): Query tensor. + q_size (List[int]): Size of the query tensor. + k (torch.Tensor): Key tensor. + k_size (List[int]): Size of the key tensor. + v (torch.Tensor): Value tensor. + v_size (List[int]): Size of the value tensor. + batch_size (List[int]): List containing batch size information. + chan_size (List[int]): List containing channel size information. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Projected q, k, and v tensors. + """ q = ( self.q(q) .reshape(batch_size, q_size, self.num_heads, chan_size // self.num_heads) @@ -368,6 +456,19 @@ def _qkv_pool( ) -> Tuple[ torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int] ]: + """ + Applies pooling to q, k, and v tensors and optionally applies gelu activation. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + thw_shape (List[int]): The shape of the input tensor (before flattening). + + Returns: + Tuple[torch.Tensor, List[int], torch.Tensor, List[int], torch.Tensor, List[int]]: + Pooled and reshaped q, k, and v tensors along with their respective shapes. + """ if self.pool_q is None: q_shape = thw_shape else: @@ -415,6 +516,17 @@ def _get_qkv_length( k_shape: List[int], v_shape: List[int], ) -> Tuple[int]: + """ + Calculates the lengths of q, k, and v tensors. + + Args: + q_shape (List[int]): Shape of the q tensor. + k_shape (List[int]): Shape of the k tensor. + v_shape (List[int]): Shape of the v tensor. + + Returns: + Tuple[int]: Lengths of q, k, and v tensors. + """ q_N = numpy.prod(q_shape) + 1 if self.has_cls_embed else numpy.prod(q_shape) k_N = numpy.prod(k_shape) + 1 if self.has_cls_embed else numpy.prod(k_shape) v_N = numpy.prod(v_shape) + 1 if self.has_cls_embed else numpy.prod(v_shape) @@ -431,6 +543,22 @@ def _reshape_qkv_to_seq( B: int, C: int, ) -> Tuple[int]: + """ + Reshapes q, k, and v tensors for sequence processing. + + Args: + q (torch.Tensor): Query tensor. + k (torch.Tensor): Key tensor. + v (torch.Tensor): Value tensor. + q_N (int): Length of the q tensor. + v_N (int): Length of the v tensor. + k_N (int): Length of the k tensor. + B (int): Batch size. + C (int): Channel size. + + Returns: + Tuple[int]: Reshaped q, k, and v tensors. + """ q = q.permute(0, 2, 1, 3).reshape(B, q_N, C) v = v.permute(0, 2, 1, 3).reshape(B, v_N, C) k = k.permute(0, 2, 1, 3).reshape(B, k_N, C) @@ -440,9 +568,14 @@ def forward( self, x: torch.Tensor, thw_shape: List[int] ) -> Tuple[torch.Tensor, List[int]]: """ + Forward pass through the YourAttentionLayer. + Args: x (torch.Tensor): Input tensor. thw_shape (List): The shape of the input tensor (before flattening). + + Returns: + Tuple[torch.Tensor, List[int]]: Output tensor and the shape information. 
""" B, N, C = x.shape @@ -493,6 +626,46 @@ class ScriptableMultiScaleBlock(nn.Module): DropPath | ↓ | Summation ←------------+ + + Args: + dim (int): Input feature dimension. + dim_out (int): Output feature dimension. + num_heads (int): Number of heads in the attention layer. + mlp_ratio (float): Mlp ratio controlling the feature dimension in the hidden layer + of the Mlp block. + qkv_bias (bool): If False, the qkv layer will not learn an additive bias. Default: False. + dropout_rate (float): DropOut rate. If 0, DropOut is disabled. + droppath_rate (float): DropPath rate. If 0, DropPath is disabled. + act_layer (nn.Module): Activation layer used in the Mlp layer. + norm_layer (nn.Module): Normalization layer. + attn_norm_layer (nn.Module): Normalization layer in the attention module. + kernel_q (_size_3_t): Pooling kernel size for q. If all dimensions have size 1, + pooling is not applied (by default). + kernel_kv (_size_3_t): Pooling kernel size for kv. If all dimensions have size 1, + pooling is disabled (by default). + stride_q (_size_3_t): Pooling kernel stride for q. + stride_kv (_size_3_t): Pooling kernel stride for kv. + pool_mode (str): Pooling mode, including "conv" (learned pooling), "avg" + (average pooling), and "max" (max pooling). + has_cls_embed (bool): If True, the first token of the input tensor is a cls token. + Otherwise, the input tensor doesn't contain a cls token, and pooling isn't applied + to the cls token. + pool_first (bool): If True, pooling is applied before qkv projection. Default: False. + residual_pool (bool): If True, use Improved Multiscale Vision Transformer's pooling + residual connection. + depthwise_conv (bool): Whether to use depthwise or full convolution for pooling. + bias_on (bool): Whether to use biases for linear layers. + separate_qkv (bool): Whether to use separate or one layer for qkv projections. + + This class represents a multiscale vision transformer block, which combines multiscale + attention and MLP layers. It takes an input tensor `x` and its shape `thw_shape`, applies + the multiscale attention, and returns the processed tensor along with the updated shape. + + Example usage: + ``` + block = ScriptableMultiScaleBlock(dim=256, dim_out=512, num_heads=8) + output, updated_shape = block(input_tensor, [16, 32, 32]) + ``` """ def __init__( diff --git a/pytorchvideo/layers/batch_norm.py b/pytorchvideo/layers/batch_norm.py index 9687e4ba..057c7e8a 100644 --- a/pytorchvideo/layers/batch_norm.py +++ b/pytorchvideo/layers/batch_norm.py @@ -9,13 +9,30 @@ class NaiveSyncBatchNorm1d(nn.BatchNorm1d): """ - An implementation of 1D naive sync batch normalization. See details in - NaiveSyncBatchNorm2d below. + 1D Naive Sync Batch Normalization for PyTorch. + + This is an implementation of 1D batch normalization that supports synchronization + across multiple devices (local or global). It extends the functionality of + PyTorch's `nn.BatchNorm1d`. Args: - num_sync_devices (int): number of (local) devices to sync. - global_sync (bool): sync across all devices (on all machines). - args (list): other arguments. + num_sync_devices (int): Number of local devices to sync with. If global_sync is True, + this parameter is ignored. + global_sync (bool): If True, syncs across all devices on all machines. + **args: Additional arguments to be passed to the base `nn.BatchNorm1d` constructor. + + Raises: + ValueError: If conflicting parameters are provided (e.g., both global_sync and num_sync_devices). 
+ + Note: + To use this synchronization, make sure to initialize the distributed environment + (e.g., using PyTorch's `torch.distributed.init_process_group`). + + Example: + ``` + sync_bn = NaiveSyncBatchNorm1d(num_sync_devices=4, global_sync=False, num_features=64) + output = sync_bn(input_tensor) + ``` """ def __init__(self, num_sync_devices=None, global_sync=True, **args): @@ -80,7 +97,11 @@ def forward(self, input): class NaiveSyncBatchNorm2d(nn.BatchNorm2d): """ - An implementation of 2D naive sync batch normalization. + 2D Naive Sync Batch Normalization for PyTorch. + + This is an implementation of 2D batch normalization that supports synchronization + across multiple devices (local or global). It serves as a correct alternative to + `nn.SyncBatchNorm` when there are varying batch sizes on different workers. In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient when the batch size on each worker is different. (e.g., when scale augmentation is used, or when it is applied to mask head). @@ -88,15 +109,22 @@ class NaiveSyncBatchNorm2d(nn.BatchNorm2d): This is a slower but correct alternative to `nn.SyncBatchNorm`. Args: - num_sync_devices (int): number of (local) devices to sync. - global_sync (bool): sync across all devices (on all machines). - args (list): other arguments. + num_sync_devices (int): Number of local devices to sync with. If global_sync is True, + this parameter is ignored. + global_sync (bool): If True, syncs across all devices on all machines. + **args: Additional arguments to be passed to the base `nn.BatchNorm2d` constructor. Note: - This module computes overall statistics by using - statistics of each worker with equal weight. The result is true statistics - of all samples (as if they are all on one worker) only when all workers - have the same (N, H, W). This mode does not support inputs with zero batch size. + This module computes overall statistics by using statistics of each worker with equal weight. + The result represents true statistics of all samples as if they are all on one worker, + provided that all workers have the same input dimensions (N, H, W). This mode does not support + inputs with zero batch size. + + Example: + ``` + sync_bn = NaiveSyncBatchNorm2d(num_sync_devices=4, global_sync=False, num_features=64) + output = sync_bn(input_tensor) + ``` """ def __init__(self, num_sync_devices=None, global_sync=True, **args): @@ -125,6 +153,15 @@ def __init__(self, num_sync_devices=None, global_sync=True, **args): super(NaiveSyncBatchNorm2d, self).__init__(**args) def forward(self, input): + """ + Forward pass through the NaiveSyncBatchNorm2d layer. + + Args: + input (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized and scaled output tensor. + """ if du.get_world_size() == 1 or not self.training: return super().forward(input) @@ -161,12 +198,29 @@ def forward(self, input): class NaiveSyncBatchNorm3d(nn.BatchNorm3d): """ - Naive version of Synchronized 3D BatchNorm. See details in - NaiveSyncBatchNorm2d above. + 3D Naive Sync Batch Normalization for PyTorch. + + This is an implementation of 3D batch normalization that supports synchronization + across multiple devices (local or global). It serves as a correct alternative to + `nn.SyncBatchNorm` when there are varying batch sizes on different workers. + Args: - num_sync_devices (int): number of (local) devices to sync. - global_sync (bool): sync across all devices (on all machines). - args (list): other arguments. + num_sync_devices (int): Number of local devices to sync with. 
If global_sync is True, + this parameter is ignored. + global_sync (bool): If True, syncs across all devices on all machines. + **args: Additional arguments to be passed to the base `nn.BatchNorm3d` constructor. + + Note: + This module computes overall statistics by using statistics of each worker with equal weight. + The result represents true statistics of all samples as if they are all on one worker, + provided that all workers have the same input dimensions (N, D, H, W). This mode does not + support inputs with zero batch size. + + Example: + ``` + sync_bn = NaiveSyncBatchNorm3d(num_sync_devices=4, global_sync=False, num_features=64) + output = sync_bn(input_tensor) + ``` """ def __init__(self, num_sync_devices=None, global_sync=True, **args): @@ -195,6 +249,15 @@ def __init__(self, num_sync_devices=None, global_sync=True, **args): super(NaiveSyncBatchNorm3d, self).__init__(**args) def forward(self, input): + """ + Forward pass through the NaiveSyncBatchNorm3d layer. + + Args: + input (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Normalized and scaled output tensor. + """ if du.get_world_size() == 1 or not self.training: return super().forward(input) diff --git a/pytorchvideo/layers/convolutions.py b/pytorchvideo/layers/convolutions.py index 35d2ebc9..7e156d13 100644 --- a/pytorchvideo/layers/convolutions.py +++ b/pytorchvideo/layers/convolutions.py @@ -11,12 +11,54 @@ class ConvReduce3D(nn.Module): """ Builds a list of convolutional operators and performs summation on the outputs. + Applies a series of 3D convolutional operations and reduces their outputs. - :: + This class takes a list of convolutional layers as input and reduces their outputs either by + summation or concatenation. + + :: Conv3d, Conv3d, ..., Conv3d ↓ - Sum + Sum + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels produced by each convolution. + kernel_size (Tuple[_size_3_t]): Tuple of sizes for the convolution kernels. + stride (Optional[Tuple[_size_3_t]]): Tuple of stride values for the convolutions. + padding (Optional[Tuple[_size_3_t]]): Tuple of padding values added to all sides of the input. + padding_mode (Optional[Tuple[str]]): Tuple of padding modes for each convolution. + Options include `zeros`, `reflect`, `replicate`, or `circular`. + dilation (Optional[Tuple[_size_3_t]]): Tuple of spacings between kernel elements. + groups (Optional[Tuple[int]]): Tuple of numbers of blocked connections from input + channels to output channels. + bias (Optional[Tuple[bool]]): If True, adds a learnable bias to the output. + reduction_method (str): Method for reducing the convolution outputs. Options + include 'sum' (summation) and 'cat' (concatenation). + + Note: + - The number of convolutional operations and their parameters depend on the length of the provided tuples. + - The `reduction_method` determines how the convolution outputs are reduced. + + Example: + ``` + # Create a ConvReduce3D instance with two convolutional layers and summation reduction. + conv_reducer = ConvReduce3D( + in_channels=3, + out_channels=64, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + dilation=(1, 1, 1), + groups=(1, 1, 1), + bias=(True, True, True), + reduction_method="sum" + ) + + # Apply the convolutional operations to an input tensor. 
+ output = conv_reducer(input_tensor) + ``` """ def __init__( @@ -75,6 +117,15 @@ def __init__( self.convs = nn.ModuleList(conv_list) def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass through the ConvReduce3D layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Reduced output tensor based on the chosen reduction method. + """ output = [] for ind in range(len(self.convs)): output.append(self.convs[ind](x)) @@ -106,8 +157,8 @@ def create_conv_2plus1d( activation: Callable = nn.ReLU, ) -> nn.Module: """ - Create a 2plus1d conv layer. It performs spatiotemporal Convolution, BN, and - Relu following by a spatiotemporal pooling. + Create a 2+1D convolutional layer. This layer performs spatiotemporal convolution, + followed by normalization and activation, and then spatiotemporal pooling. :: @@ -144,7 +195,11 @@ def create_conv_2plus1d( activation). Returns: - (nn.Module): 2plus1d conv layer. + nn.Module: 2+1D convolutional layer. + + Note: + - The number of convolutional operations and their parameters depend on the length of the provided tuples. + - The `conv_xy_first` argument determines the order of spatial and temporal convolutions. """ if inner_channels is None: inner_channels = out_channels @@ -190,21 +245,32 @@ def create_conv_2plus1d( class Conv2plus1d(nn.Module): """ - Implementation of 2+1d Convolution by factorizing 3D Convolution into an 1D temporal - Convolution and a 2D spatial Convolution with Normalization and Activation module - in between: + Implementation of a 2+1D convolutional layer. This layer factorizes a 3D + convolution into a 1D temporal convolution followed by a 2D spatial convolution + with normalization and activation modules in between: :: - Conv_t (or Conv_xy if conv_xy_first = True) - ↓ - Normalization - ↓ - Activation - ↓ - Conv_xy (or Conv_t if conv_xy_first = True) + Conv_t (or Conv_xy if conv_xy_first = True) + ↓ + Normalization + ↓ + Activation + ↓ + Conv_xy (or Conv_t if conv_xy_first = True) - The 2+1d Convolution is used to build the R(2+1)D network. + The 2+1D convolution is commonly used to build the R(2+1)D network. + + Args: + conv_t (torch.nn.Module): Temporal convolution module. + norm (torch.nn.Module): Normalization module. + activation (torch.nn.Module): Activation module. + conv_xy (torch.nn.Module): Spatial convolution module. + conv_xy_first (bool): If True, spatial convolution comes before temporal convolution. + + Note: + - The provided modules define the components of the 2+1D convolution layer. + - The `conv_xy_first` argument determines the order of spatial and temporal convolutions. """ def __init__( @@ -217,12 +283,14 @@ def __init__( conv_xy_first: bool = False, ) -> None: """ + Initialize the Conv2plus1d layer. + Args: - conv_t (torch.nn.modules): temporal convolution module. - norm (torch.nn.modules): normalization module. - activation (torch.nn.modules): activation module. - conv_xy (torch.nn.modules): spatial convolution module. - conv_xy_first (bool): If True, spatial convolution comes before temporal conv + conv_t (torch.nn.Module): Temporal convolution module. + norm (torch.nn.Module): Normalization module. + activation (torch.nn.Module): Activation module. + conv_xy (torch.nn.Module): Spatial convolution module. + conv_xy_first (bool): If True, spatial convolution comes before temporal convolution. 
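The factorization behind `Conv2plus1d` above replaces one 3x3x3 convolution with a spatial 1x3x3 convolution and a temporal 3x1x1 convolution, with normalization and activation in between. A rough parameter-count sketch that omits the intermediate modules and any `inner_channels` adjustment:

```
import torch.nn as nn


def n_params(m):
    return sum(p.numel() for p in m.parameters())


full_3d = nn.Conv3d(64, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
factorized = nn.Sequential(
    nn.Conv3d(64, 64, kernel_size=(1, 3, 3), padding=(0, 1, 1)),
    nn.Conv3d(64, 64, kernel_size=(3, 1, 1), padding=(1, 0, 0)),
)
print(n_params(full_3d), n_params(factorized))   # 110656 vs 49280
```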
""" super().__init__() set_attributes(self, locals()) @@ -230,6 +298,15 @@ def __init__( assert self.conv_xy is not None def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Perform a forward pass through the Conv2plus1d layer. + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying the convolution and intermediate modules. + """ x = self.conv_xy(x) if self.conv_xy_first else self.conv_t(x) x = self.norm(x) if self.norm else x x = self.activation(x) if self.activation else x diff --git a/pytorchvideo/layers/distributed.py b/pytorchvideo/layers/distributed.py index 5e309b38..cb09392a 100644 --- a/pytorchvideo/layers/distributed.py +++ b/pytorchvideo/layers/distributed.py @@ -12,8 +12,10 @@ def get_world_size() -> int: """ - Simple wrapper for correctly getting worldsize in both distributed - / non-distributed settings + Get the total world size, accounting for distributed or non-distributed settings. + + Returns: + int: Total world size (number of processes). """ return ( torch.distributed.get_world_size() @@ -23,35 +25,55 @@ def get_world_size() -> int: def cat_all_gather(tensors, local=False): - """Performs the concatenated all_reduce operation on the provided tensors.""" + """ + Perform the concatenated all-gather operation on the provided tensors. + + Args: + tensors (torch.Tensor): The tensor(s) to gather and concatenate. + local (bool): If True, gather within the local process group. + + Returns: + torch.Tensor: The concatenated result tensor. + """ if local: gather_sz = get_local_size() else: gather_sz = torch.distributed.get_world_size() + tensors_gather = [torch.ones_like(tensors) for _ in range(gather_sz)] + torch.distributed.all_gather( tensors_gather, tensors, async_op=False, group=_LOCAL_PROCESS_GROUP if local else None, ) + output = torch.cat(tensors_gather, dim=0) return output def init_distributed_training(num_gpus, shard_id): """ - Initialize variables needed for distributed training. + Initialize variables required for distributed training. + + Args: + num_gpus (int): The number of GPUs per machine. + shard_id (int): The shard ID of the current machine. """ if num_gpus <= 1: return + num_gpus_per_machine = num_gpus num_machines = dist.get_world_size() // num_gpus_per_machine + for i in range(num_machines): ranks_on_i = list( range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine) ) + pg = dist.new_group(ranks_on_i) + if i == shard_id: global _LOCAL_PROCESS_GROUP _LOCAL_PROCESS_GROUP = pg @@ -59,9 +81,10 @@ def init_distributed_training(num_gpus, shard_id): def get_local_size() -> int: """ + Get the size of the per-machine process group, i.e., the number of processes per machine. + Returns: - The size of the per-machine process group, - i.e. the number of processes per machine. + int: The size of the per-machine process group. """ if not dist.is_available(): return 1 @@ -72,8 +95,10 @@ def get_local_size() -> int: def get_local_rank() -> int: """ + Get the rank of the current process within the local (per-machine) process group. + Returns: - The rank of the current process within the local (per-machine) process group. + int: The rank of the current process within the local process group. """ if not dist.is_available(): return 0 @@ -84,6 +109,12 @@ def get_local_rank() -> int: def get_local_process_group() -> ProcessGroup: + """ + Get the local (per-machine) process group. + + Returns: + ProcessGroup: The local process group. 
+ """ assert _LOCAL_PROCESS_GROUP is not None return _LOCAL_PROCESS_GROUP diff --git a/pytorchvideo/layers/drop_path.py b/pytorchvideo/layers/drop_path.py index 8023e10b..9465cbbd 100644 --- a/pytorchvideo/layers/drop_path.py +++ b/pytorchvideo/layers/drop_path.py @@ -8,32 +8,48 @@ def drop_path( x: torch.Tensor, drop_prob: float = 0.0, training: bool = False ) -> torch.Tensor: """ - Stochastic Depth per sample. + Apply stochastic depth regularization to the input tensor. + + Stochastic Depth is a regularization technique used in deep neural networks. + During training, it randomly drops (sets to zero) a fraction of the input tensor + elements to prevent overfitting. During inference, no elements are dropped. Args: - x (tensor): Input tensor. - drop_prob (float): Probability to apply drop path. - training (bool): If True, apply drop path to input. Otherwise (tesing), return input. + x (torch.Tensor): Input tensor. + drop_prob (float): Probability to apply drop path (0.0 means no drop). + training (bool): If True, apply drop path during training; otherwise, return the input. + + Returns: + torch.Tensor: Output tensor after applying drop path. """ if drop_prob == 0.0 or not training: + # If drop probability is 0 or not in training mode, return the input as is. return x + keep_prob = 1 - drop_prob - shape = (x.shape[0],) + (1,) * ( - x.ndim - 1 - ) # work with diff dim tensors, not just 2D ConvNets + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # Adjust shape for various tensor dimensions. mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) - mask.floor_() # binarize + mask.floor_() # Binarize the mask. + + # Scale the input tensor and apply the mask to drop elements. output = x.div(keep_prob) * mask + return output class DropPath(nn.Module): """ Drop paths (Stochastic Depth) per sample. + + Drop path is a regularization technique used in deep neural networks to + randomly drop (set to zero) a fraction of input tensor elements during training + to prevent overfitting. """ def __init__(self, drop_prob: float = 0.0) -> None: """ + Initialize the DropPath module. + Args: drop_prob (float): Probability to apply drop path. """ @@ -42,7 +58,13 @@ def __init__(self, drop_prob: float = 0.0) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ + Apply drop path regularization to the input tensor. + Args: - x (tensor): Input tensor. + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor after applying drop path. """ + # Call the drop_path function to apply drop path to the input tensor. return drop_path(x, self.drop_prob, self.training) diff --git a/pytorchvideo/layers/fusion.py b/pytorchvideo/layers/fusion.py index 9656bec4..895c7930 100644 --- a/pytorchvideo/layers/fusion.py +++ b/pytorchvideo/layers/fusion.py @@ -16,6 +16,12 @@ def make_fusion_layer(method: str, feature_dims: List[int]): """ + Drop paths (Stochastic Depth) per sample. + + Drop path is a regularization technique used in deep neural networks to + randomly drop (set to zero) a fraction of input tensor elements during training + to prevent overfitting. + Args: method (str): the fusion method to be constructed. Options: - 'concat' @@ -28,6 +34,11 @@ def make_fusion_layer(method: str, feature_dims: List[int]): of required feature_dims for each tensor input (where the tensor inputs are of shape (batch_size, seq_len, feature_dim)). The list order must corresponds to the tensor order passed to forward(...). 
""" if method == "concat": return ConcatFusion(feature_dims) @@ -45,11 +56,29 @@ class ConcatFusion(nn.Module): """ - Concatenates all inputs by their last dimension. The resulting tensor last dim will be - the sum of the last dimension of all input tensors. + Concatenates multiple input tensors along their last dimension to create a fused tensor. + The size of the last dimension in the resulting tensor is the sum of the last dimensions + of all input tensors. + + Args: + feature_dims (List[int]): A list of feature dimensions for each input tensor. + + Attributes: + output_dim (int): The size of the last dimension in the fused tensor. + + Example: + If feature_dims is [64, 128, 256], and three tensors of shape (batch_size, seq_len, 64), + (batch_size, seq_len, 128), and (batch_size, seq_len, 256) are concatenated, the output + tensor will have a shape of (batch_size, seq_len, 448) because 64 + 128 + 256 = 448. """ def __init__(self, feature_dims: List[int]): + """ + Initialize the ConcatFusion module. + + Args: + feature_dims (List[int]): A list of feature dimensions for each input tensor. + """ super().__init__() _verify_feature_dim(feature_dims) self._output_dim = sum(feature_dims) @@ -57,29 +86,51 @@ def __init__(self, feature_dims: List[int]): @property def output_dim(self): """ - Last dimension size of forward(..) tensor output. + Get the size of the last dimension in the fused tensor. + + Returns: + int: The size of the last dimension in the fused tensor. """ return self._output_dim def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor: """ + Concatenate a list of input tensors along their last dimension. + Args: - input_list (List[torch.Tensor]): a list of tensors of shape - (batch_size, seq_len, feature_dim). + input_list (List[torch.Tensor]): A list of tensors to be concatenated. Returns: - Tensor of shape (batch_size, seq_len, sum(feature_dims)) where sum(feature_dims) - is the sum of all input feature_dims. + torch.Tensor: A tensor resulting from the concatenation of input tensors. + The size of the last dimension is the sum of the feature dimensions + of all input tensors. """ return torch.cat(input_list, dim=-1) class TemporalConcatFusion(nn.Module): """ - Concatenates all inputs by their temporal dimension which is assumed to be dim=1. + Concatenates input tensors along their temporal dimension (assumed to be dim=1). + + This module takes a list of input tensors, each with shape (batch_size, seq_len, feature_dim), + and concatenates them along the temporal dimension (dim=1). + + Args: + feature_dims (List[int]): List of feature dimensions of the input tensors. + + Note: + - All input tensors must have the same feature dimension. + - The output tensor will have shape (batch_size, sum(seq_len), feature_dim), + where sum(seq_len) is the sum of seq_len for all input tensors. + """ def __init__(self, feature_dims: List[int]): + """ + Initialize the TemporalConcatFusion module. + + Args: + feature_dims (List[int]): List of feature dimensions of the input tensors. + """ super().__init__() _verify_feature_dim(feature_dims) @@ -90,28 +141,45 @@ def __init__(self, feature_dims: List[int]): @property def output_dim(self): """ - Last dimension size of forward(..) tensor output.
+ Get the last dimension size of the output tensor produced by the forward(..) method. + + Returns: + int: Last dimension size of the forward(..) tensor output. """ return self._output_dim def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor: """ + Perform forward pass through the TemporalConcatFusion module. + Args: - input_list (List[torch.Tensor]): a list of tensors of shape - (batch_size, seq_len, feature_dim) + input_list (List[torch.Tensor]): A list of tensors of shape + (batch_size, seq_len, feature_dim). Returns: - Tensor of shape (batch_size, sum(seq_len), feature_dim) where sum(seq_len) is - the sum of all input tensors. + torch.Tensor: Output tensor of shape (batch_size, sum(seq_len), feature_dim), + where sum(seq_len) is the sum of all input tensors' seq_len. """ return torch.cat(input_list, dim=1) class ReduceFusion(nn.Module): """ - Generic fusion method which takes a callable which takes the list of input tensors - and expects a single tensor to be used. This class can be used to implement fusion - methods like "sum", "max" and "prod". + A generic fusion method that applies a specified reduction function to a list of input tensors + to produce a single output tensor. This class can be used to implement fusion methods like "sum", + "max", and "prod". + + Args: + feature_dims (List[int]): List of feature dimensions for the input tensors. + reduce_fn (Callable[[torch.Tensor], torch.Tensor]): A callable reduction function that takes + the list of input tensors and returns a single tensor. + + Attributes: + output_dim (int): The dimension of the output tensor after fusion, which is the maximum + of the input feature dimensions. + + Note: + - The input tensors must have consistent feature dimensions for fusion to work correctly. """ def __init__( @@ -122,28 +190,31 @@ def __init__( self.reduce_fn = reduce_fn # All input dimensions must be the same - self._output_dim = max(feature_dims) - assert self._output_dim == min(feature_dims) - - @property - def output_dim(self): - """ - Last dimension size of forward(..) tensor output. - """ - return self._output_dim + self.output_dim = max(feature_dims) + assert self.output_dim == min(feature_dims) def forward(self, input_list: List[torch.Tensor]) -> torch.Tensor: """ + Forward pass of the ReduceFusion module. + Args: - input_list (List[torch.Tensor]): a list of tensors of shape - (batch_size, seq_len, feature_dim). + input_list (List[torch.Tensor]): A list of tensors to be fused. Returns: - Tensor of shape (batch_size, seq_len, feature_dim). + torch.Tensor: The fused tensor after applying the reduction function. """ return self.reduce_fn(torch.stack(input_list)) def _verify_feature_dim(feature_dims: List[int]): + """ + Verify that the feature dimensions in the list are valid. + + Args: + feature_dims (List[int]): List of feature dimensions. + + Raises: + AssertionError: If any feature dimension is non-positive or if the list is empty. + """ assert isinstance(feature_dims, list) - assert all(x > 0 for x in feature_dims) + assert all(x > 0 for x in feature_dims) \ No newline at end of file diff --git a/pytorchvideo/layers/mlp.py b/pytorchvideo/layers/mlp.py index 78556e77..f13fad10 100644 --- a/pytorchvideo/layers/mlp.py +++ b/pytorchvideo/layers/mlp.py @@ -13,24 +13,52 @@ def make_multilayer_perceptron( dropout_rate: float = 0.0, ) -> Tuple[nn.Module, int]: """ - Factory function for Multi-Layer Perceptron. 
These are constructed as repeated - blocks of the following format where each fc represents the blocks output/input dimension. - - :: - - Linear (in=fc[i-1], out=fc[i]) - ↓ - Normalization (norm) - ↓ - Activation (mid_activation) - ↓ - After the repeated Perceptron blocks, - a final dropout and activation layer is applied: - ↓ - Dropout (p=dropout_rate) - ↓ - Activation (final_activation) + Create a Multi-Layer Perceptron (MLP) with a customizable architecture. + Args: + fully_connected_dims (List[int]): A list of integers specifying the dimensions of + fully connected layers in the MLP. The list should have at least two elements, + where the first element is the input dimension and the last element is the output + dimension. + + norm (Optional[Callable]): A callable normalization function to be applied after each + hidden fully connected layer (e.g., nn.BatchNorm1d). If None, no normalization is applied. + + mid_activation (Callable): A callable activation function to be applied after each + fully connected layer except the last one (e.g., nn.ReLU). + + final_activation (Optional[Callable]): A callable activation function to be applied + after the last fully connected layer. If None, no activation is applied after + the final layer. + + dropout_rate (float): The dropout rate applied once after the last fully connected + layer, before the final activation. If 0.0, no dropout is applied. + + Returns: + Tuple[nn.Module, int]: A tuple containing the MLP module and the output dimension. + + Example: + To create a simple MLP with two hidden layers of size 64 and 32 (input dimension 128, output dimension 10): + ``` + mlp, output_dim = make_multilayer_perceptron( + fully_connected_dims=[128, 64, 32, 10], + norm=nn.BatchNorm1d, + mid_activation=nn.ReLU, + final_activation=nn.Sigmoid, + dropout_rate=0.1 + ) + ``` + + Note: + - The `fully_connected_dims` list must have at least two elements, with the first + element representing the input dimension and the last element representing the output + dimension. + - You can customize the architecture of the MLP by specifying the number of hidden layers + and their dimensions. + - `mid_activation` is applied after each hidden layer; `final_activation`, if provided, + follows the last layer. + - If `norm` is provided, it is applied after each hidden layer. + - If `dropout_rate` is greater than 0.0, a single dropout layer is applied after the + last fully connected layer. """ assert isinstance(fully_connected_dims, list) assert len(fully_connected_dims) > 1 diff --git a/pytorchvideo/layers/nonlocal_net.py b/pytorchvideo/layers/nonlocal_net.py index a6ae91f5..57f25bd9 100644 --- a/pytorchvideo/layers/nonlocal_net.py +++ b/pytorchvideo/layers/nonlocal_net.py @@ -9,15 +9,47 @@ class NonLocal(nn.Module): """ - Builds Non-local Neural Networks as a generic family of building - blocks for capturing long-range dependencies. Non-local Network - computes the response at a position as a weighted sum of the - features at all positions. This building block can be plugged into - many computer vision architectures. - More details in the paper: + Implementation of Non-local Neural Networks, which capture long-range dependencies + in feature maps. Non-local Network computes the response at a position as a weighted + sum of the features at all positions. This building block can be integrated into + various computer vision architectures. + + Reference: Wang, Xiaolong, Ross Girshick, Abhinav Gupta, and Kaiming He. "Non-local neural networks." In Proceedings of the IEEE conference on CVPR, 2018.
+ + Args: + conv_theta (nn.Module): Convolutional layer for computing the 'theta' transformation. + conv_phi (nn.Module): Convolutional layer for computing the 'phi' transformation. + conv_g (nn.Module): Convolutional layer for computing the 'g' transformation. + conv_out (nn.Module): Convolutional layer for the output transformation. + pool (Optional[nn.Module]): Optional temporal-spatial pooling layer to reduce computation. + norm (Optional[nn.Module]): Optional normalization layer to be applied to the output. + instantiation (str): The type of normalization used. Options are 'dot_product' and 'softmax'. + + Note: + - The 'conv_theta', 'conv_phi', 'conv_g', and 'conv_out' modules should have + matching output and input dimensions. + - 'pool' can be used for temporal-spatial pooling to reduce computation. + - 'instantiation' determines the type of normalization applied to the affinity tensor. + + Example: + To create a Non-local block: + ``` + non_local_block = NonLocal( + conv_theta=nn.Conv3d(in_channels, inner_channels, kernel_size=1), + conv_phi=nn.Conv3d(in_channels, inner_channels, kernel_size=1), + conv_g=nn.Conv3d(in_channels, inner_channels, kernel_size=1), + conv_out=nn.Conv3d(inner_channels, in_channels, kernel_size=1), + pool=nn.MaxPool3d(kernel_size=(1, 2, 2)), + norm=nn.BatchNorm3d(inner_channels), + instantiation='dot_product' + ) + ``` + + Returns: + torch.Tensor: The output tensor with long-range dependencies captured. """ def __init__( @@ -105,27 +137,54 @@ def create_nonlocal( norm_momentum: float = 0.1, ): """ - Builds Non-local Neural Networks as a generic family of building - blocks for capturing long-range dependencies. Non-local Network - computes the response at a position as a weighted sum of the - features at all positions. This building block can be plugged into + Create a Non-local Neural Network block for capturing long-range dependencies in computer + vision architectures.Non-local Network computes the response at a position as a weighted + sum of the features at all positions. This building block can be plugged into many computer vision architectures. More details in the paper: https://arxiv.org/pdf/1711.07971 + Args: - dim_in (int): number of dimension for the input. - dim_inner (int): number of dimension inside of the Non-local block. - pool_size (tuple[int]): the kernel size of spatial temporal pooling, - temporal pool kernel size, spatial pool kernel size, spatial pool kernel - size in order. By default pool_size is None, then there would be no pooling - used. - instantiation (string): supports two different instantiation method: - "dot_product": normalizing correlation matrix with L2. - "softmax": normalizing correlation matrix with Softmax. - norm (nn.Module): nn.Module for the normalization layer. The default is + dim_in (int): The number of dimensions for the input. + dim_inner (int): The number of dimensions inside the Non-local block. + pool_size (tuple[int]): The kernel size of spatial-temporal pooling. It consists of + three integers representing the temporal pool kernel size, spatial pool kernel + width, and spatial pool kernel height, respectively. If set to (1, 1, 1), no + pooling is used. Default is (1, 1, 1). + instantiation (string): The instantiation method for normalizing the correlation + matrix. Supports two options: "dot_product" (normalizing correlation matrix + with L2) and "softmax" (normalizing correlation matrix with Softmax). + norm (nn.Module): An instance of nn.Module for the normalization layer. Default is nn.BatchNorm3d. 
- norm_eps (float): normalization epsilon. - norm_momentum (float): normalization momentum. + norm_eps (float): The normalization epsilon. + norm_momentum (float): The normalization momentum. + + Returns: + NonLocal: A Non-local Neural Network block that can be integrated into computer + vision architectures. + + Example: + To create a Non-local block with a temporal pool size of 2x2x2 and "dot_product" + instantiation: + ``` + non_local_block = create_nonlocal( + dim_in=256, + dim_inner=128, + pool_size=(2, 2, 2), + instantiation="dot_product", + norm=nn.BatchNorm3d, + norm_eps=1e-5, + norm_momentum=0.1 + ) + ``` + + Note: + - The Non-local block is a useful building block for capturing long-range + dependencies in computer vision tasks. + - You can customize the architecture of the Non-local block by specifying the + input dimension (`dim_in`), inner dimension (`dim_inner`), pooling size + (`pool_size`), and normalization settings (`norm`, `norm_eps`, and `norm_momentum`). """ + if pool_size is None: pool_size = (1, 1, 1) assert isinstance(pool_size, Iterable) diff --git a/pytorchvideo/layers/positional_encoding.py b/pytorchvideo/layers/positional_encoding.py index 1f67e9e3..a3a957b1 100644 --- a/pytorchvideo/layers/positional_encoding.py +++ b/pytorchvideo/layers/positional_encoding.py @@ -10,18 +10,34 @@ class PositionalEncoding(nn.Module): """ - Applies a positional encoding to a tensor with shape (batch_size x seq_len x embed_dim). + Applies positional encoding to a tensor with shape (batch_size x seq_len x embed_dim). + + Positional encoding is a crucial component in transformers. It helps the model + capture the sequential order of the input data. The positional encoding is computed as follows: - PE(pos,2i) = sin(pos/10000^(2i/dmodel)) - PE(pos,2i+1) = cos(pos/10000^(2i/dmodel)) + PE(pos, 2i) = sin(pos / 10000^(2i / dmodel)) + PE(pos, 2i+1) = cos(pos / 10000^(2i / dmodel)) - where pos = position, pos in [0, seq_len) - dmodel = data embedding dimension = embed_dim - i = dimension index, i in [0, embed_dim) + where: + - pos: position in the sequence, pos in [0, seq_len) + - dmodel: data embedding dimension = embed_dim + - i: dimension index, i in [0, embed_dim) Reference: "Attention Is All You Need" https://arxiv.org/abs/1706.03762 Implementation Reference: https://pytorch.org/tutorials/beginner/transformer_tutorial.html + + Args: + embed_dim (int): The embedding dimension of the input data. + seq_len (int): The maximum sequence length for which the positional encoding is calculated. + + Attributes: + pe (torch.Tensor): The precomputed positional encodings. + + Example: + >>> positional_encoder = PositionalEncoding(512, 1000) + >>> input_data = torch.randn(32, 100, 512) + >>> output_data = positional_encoder(input_data) """ def __init__(self, embed_dim: int, seq_len: int = 1024) -> None: @@ -140,19 +156,25 @@ def get_3d_sincos_pos_embed( embed_dim: int, grid_size: int, t_size: int, cls_token: bool = False ) -> torch.Tensor: """ - Get 3D sine-cosine positional embedding. + Get 3D sine-cosine positional embedding for a 3D grid. + Args: - grid_size: int of the grid height and width - t_size: int of the temporal size - cls_token: bool, whether to contain CLS token + embed_dim (int): The total embedding dimension. It should be divisible by 4. + The embedding is split into three parts: spatial (3/4 of embed_dim) and + temporal (1/4 of embed_dim). + grid_size (int): The size of the grid in both height and width dimensions. + t_size (int): The temporal size of the grid. 
+ cls_token (bool): Whether to include a CLS token in the output. + Returns: - (torch.Tensor): [t_size*grid_size*grid_size, embed_dim] or [1+t_size*grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + torch.Tensor: A positional embedding tensor of shape [t_size*grid_size*grid_size, embed_dim] + if cls_token is False, or [1 + t_size*grid_size*grid_size, embed_dim] if cls_token is True. """ assert embed_dim % 4 == 0 embed_dim_spatial = embed_dim // 4 * 3 embed_dim_temporal = embed_dim // 4 - # spatial + # Generate spatial positional embeddings grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) grid = np.meshgrid(grid_w, grid_h) @@ -161,7 +183,7 @@ def get_3d_sincos_pos_embed( grid = grid.reshape([2, 1, grid_size, grid_size]) pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid) - # temporal + # Generate temporal positional embeddings grid_t = np.arange(t_size, dtype=np.float32) pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t) @@ -175,19 +197,24 @@ def get_3d_sincos_pos_embed( if cls_token: pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed + + return torch.tensor(pos_embed) def get_2d_sincos_pos_embed( embed_dim: int, grid_size: int, cls_token: bool = False ) -> torch.Tensor: """ - Get 2D sine-cosine positional embedding. + Get 2D sine-cosine positional embedding for a 2D grid. + Args: - grid_size: int of the grid height and width - cls_token: bool, whether to contain CLS token + embed_dim (int): The embedding dimension. + grid_size (int): The grid height and width. + cls_token (bool): Whether to include a CLS token. + Returns: - (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + torch.Tensor: A 2D positional embedding tensor of shape [grid_size*grid_size, embed_dim] + or [1+grid_size*grid_size, embed_dim] if cls_token is True. """ grid_h = np.arange(grid_size, dtype=np.float32) grid_w = np.arange(grid_size, dtype=np.float32) @@ -203,12 +230,15 @@ def get_2d_sincos_pos_embed( def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> torch.Tensor: """ - Get 2D sine-cosine positional embedding from grid. + Get 2D sine-cosine positional embedding from a grid. + Args: - embed_dim: embedding dimension. - grid: positions + embed_dim (int): The embedding dimension. + grid (np.ndarray): A 2D grid of positions. + Returns: - (torch.Tensor): [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + torch.Tensor: A 2D positional embedding tensor of shape [grid_size*grid_size, embed_dim] + or [1+grid_size*grid_size, embed_dim] if cls_token is True. """ assert embed_dim % 2 == 0 @@ -222,12 +252,16 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: np.ndarray) -> torch def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: np.ndarray) -> torch.Tensor: """ - Get 1D sine-cosine positional embedding. + Get 1D sine-cosine positional embedding for a 1D array of positions. + Args: - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) + embed_dim (int): The output dimension for each position. + pos (np.ndarray): A 1D array of positions to be encoded. + Returns: - (torch.Tensor): tensor of shape (M, D) + torch.Tensor: A tensor of shape (M, D), where M is the number of positions + and D is the embedding dimension. 
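+ + Example (illustrative; encodes 16 positions into a 64-dimensional embedding): + >>> pos = np.arange(16, dtype=float) + >>> emb = get_1d_sincos_pos_embed_from_grid(64, pos) # embedding of shape (16, 64)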
+ """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=float) diff --git a/pytorchvideo/layers/positional_encoding_torchscript.py b/pytorchvideo/layers/positional_encoding_torchscript.py index 5f3d328d..3f795175 100644 --- a/pytorchvideo/layers/positional_encoding_torchscript.py +++ b/pytorchvideo/layers/positional_encoding_torchscript.py @@ -8,7 +8,29 @@ class ScriptableSpatioTemporalClsPositionalEncoding(nn.Module): """ - Add a cls token and apply a spatiotemporal encoding to a tensor. + Add a cls token and apply spatiotemporal encoding to a tensor. + + This module is used for positional encoding in spatiotemporal models. + It adds positional embeddings to the input tensor, which can be separated + into spatial and temporal components or kept as a single positional encoding. + + Args: + embed_dim (int): The embedding dimension for the input sequence. + patch_embed_shape (Tuple[int, int, int]): The number of patches in each dimension + (T, H, W) after patch embedding. + sep_pos_embed (bool): If set to True, separate positional encodings are used for + spatial patches and temporal sequences. Otherwise, a single positional encoding + is used for all patches. + has_cls (bool): If set to True, a cls token is added to the beginning of each + input sequence. + + Note: + - `patch_embed_shape` should be provided as a tuple in the form (T, H, W). + - When `sep_pos_embed` is set to True, two positional encodings are used: one for + spatial patches and one for temporal sequences. Otherwise, only one positional + encoding is used for all patches. + - If `has_cls` is set to True, a cls token is added to the beginning of each input + sequence. """ def __init__( @@ -18,34 +40,18 @@ def __init__( sep_pos_embed: bool = False, has_cls: bool = True, ) -> None: - """ - Args: - embed_dim (int): Embedding dimension for input sequence. - patch_embed_shape (Tuple): The number of patches in each dimension - (T, H, W) after patch embedding. - sep_pos_embed (bool): If set to true, one positional encoding is used for - spatial patches and another positional encoding is used for temporal - sequence. Otherwise, only one positional encoding is used for all the - patches. - has_cls (bool): If set to true, a cls token is added in the beginning of each - input sequence. - """ + super().__init__() - assert ( - len(patch_embed_shape) == 3 - ), "Patch_embed_shape should be in the form of (T, H, W)." - assert not has_cls + assert len(patch_embed_shape) == 3, "Patch_embed_shape should be in the form of (T, H, W)." + assert not has_cls # This implementation currently does not support cls token. self.sep_pos_embed = sep_pos_embed self._patch_embed_shape = patch_embed_shape self.num_spatial_patch = patch_embed_shape[1] * patch_embed_shape[2] self.num_temporal_patch = patch_embed_shape[0] - self.pos_embed_spatial = nn.Parameter( - torch.zeros(1, self.num_spatial_patch, embed_dim) - ) - self.pos_embed_temporal = nn.Parameter( - torch.zeros(1, self.num_temporal_patch, embed_dim) - ) + # Initialize spatial and temporal positional embeddings. + self.pos_embed_spatial = nn.Parameter(torch.zeros(1, self.num_spatial_patch, embed_dim)) + self.pos_embed_temporal = nn.Parameter(torch.zeros(1, self.num_temporal_patch, embed_dim)) @property def patch_embed_shape(self): @@ -53,18 +59,20 @@ def patch_embed_shape(self): def forward(self, x: torch.Tensor) -> torch.Tensor: """ + Apply spatiotemporal positional encoding to the input tensor. + Args: x (torch.Tensor): Input tensor. 
+ + Returns: + torch.Tensor: Output tensor with positional encoding applied. """ B, N, C = x.shape assert self.sep_pos_embed - pos_embed = self.pos_embed_spatial.repeat( - 1, self.num_temporal_patch, 1 - ) + torch.repeat_interleave( - self.pos_embed_temporal, - self.num_spatial_patch, - dim=1, - ) + pos_embed = self.pos_embed_spatial.repeat(1, self.num_temporal_patch, 1) + \ + torch.repeat_interleave(self.pos_embed_temporal, self.num_spatial_patch, dim=1) + x = x + pos_embed return x + \ No newline at end of file diff --git a/pytorchvideo/layers/squeeze_excitation.py b/pytorchvideo/layers/squeeze_excitation.py index 47858e0a..d81201c8 100644 --- a/pytorchvideo/layers/squeeze_excitation.py +++ b/pytorchvideo/layers/squeeze_excitation.py @@ -8,7 +8,21 @@ class SqueezeAndExcitationLayer2D(nn.Module): - """2D Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf""" + """ + 2D Squeeze and Excitation (SE) layer as described in the paper: + "Squeeze-and-Excitation Networks" (https://arxiv.org/pdf/1709.01507.pdf). + + This layer enhances the representational power of a convolutional neural network + by adaptively recalibrating the feature maps. + + Args: + in_planes (int): Input channel dimension. + reduction_ratio (int): The reduction ratio used to reduce the input channels + before scaling. It controls the number of parameters in the SE layer. + reduced_planes (Optional[int]): Output channel dimension. If specified, it + overrides the reduction_ratio. Only one of reduction_ratio or reduced_planes + should be defined. + """ def __init__( self, @@ -16,26 +30,19 @@ def __init__( reduction_ratio: Optional[int] = 16, reduced_planes: Optional[int] = None, ): - - """ - Args: - in_planes (int): input channel dimension. - reduction_ratio (int): factor by which in_planes should be reduced to - get the output channel dimension. - reduced_planes (int): Output channel dimension. Only one of reduction_ratio - or reduced_planes should be defined. - """ + super().__init__() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Either reduction_ratio is defined, or out_planes is defined assert bool(reduction_ratio) != bool( reduced_planes - ), "Only of reduction_ratio or reduced_planes should be defined for SE layer" + ), "Only one of reduction_ratio or reduced_planes should be defined for SE layer" reduced_planes = ( in_planes // reduction_ratio if reduced_planes is None else reduced_planes ) + self.excitation = nn.Sequential( nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True), nn.ReLU(), @@ -45,8 +52,14 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ + Forward pass of the Squeeze and Excitation layer. + Args: - x (tensor): 2D image of format C * H * W + x (torch.Tensor): Input feature map of shape (batch_size, in_planes, H, W). + + Returns: + torch.Tensor: Output feature map after applying the SE operation. + Same shape as the input. """ x_squeezed = self.avgpool(x) x_excited = self.excitation(x_squeezed) @@ -132,7 +145,22 @@ def create_audio_2d_squeeze_excitation_block( and None (not performing activation). Returns: - (nn.Module): resnet basic block layer. + nn.Module: A 2D Residual block with Squeeze-and-Excitation for audio processing. + + Example: + To create a SE2D block with default settings: + ``` + block = create_audio_2d_squeeze_excitation_block( + dim_in=64, + dim_out=128, + use_se=True, + ) + ``` + + Note: + - This block performs spatial and feature transformations using two convolutional + branches. 
+ - The SE layer, if enabled, recalibrates channel-wise feature importance. """ branch2 = [ diff --git a/pytorchvideo/layers/swish.py b/pytorchvideo/layers/swish.py index a4e8198b..bd235b52 100644 --- a/pytorchvideo/layers/swish.py +++ b/pytorchvideo/layers/swish.py @@ -6,7 +6,43 @@ class Swish(nn.Module): """ - Wrapper for the Swish activation function. + Swish activation function: x * sigmoid(x). + + Swish is a non-linear activation function that has shown promising results in + neural network architectures. It is defined as the element-wise product of + the input tensor and the sigmoid of the input tensor. + + References: + - "Searching for activation functions" by Prajit Ramachandran, Barret Zoph, and Quoc V. Le (2017) + + Example: + ```python + activation = Swish() + output = activation(input_tensor) + ``` + + Note: + - The Swish function has been found to be effective in various deep learning tasks. + - It is differentiable and often produces smoother gradients compared to ReLU. + + Returns: + torch.Tensor: The tensor after applying the Swish activation. + + Shape: + - Input: Any shape; the activation is applied element-wise. + - Output: Same shape as the input. + + Examples: + >>> activation = Swish() + >>> input_tensor = torch.tensor([1.0, 2.0, 3.0]) + >>> output = activation(input_tensor) + >>> output + tensor([0.7311, 1.7616, 2.8577]) + + """ def forward(self, x): @@ -15,10 +51,14 @@ def forward(self, x): class SwishFunction(torch.autograd.Function): """ - Implementation of the Swish activation function: x * sigmoid(x). + Autograd function for the Swish activation. + + Args: + - ctx (context): A context object to save information for backward pass. + - x (Tensor): The input tensor. - Searching for activation functions. Ramachandran, Prajit and Zoph, Barret - and Le, Quoc V. 2017 + Returns: + - result (Tensor): The output tensor after applying the Swish activation. """ @staticmethod @@ -32,3 +72,4 @@ def backward(ctx, grad_output): (x,) = ctx.saved_tensors sigmoid_x = torch.sigmoid(x) return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x))) + \ No newline at end of file diff --git a/pytorchvideo/layers/utils.py b/pytorchvideo/layers/utils.py index 15593d61..fd3888a4 100644 --- a/pytorchvideo/layers/utils.py +++ b/pytorchvideo/layers/utils.py @@ -6,25 +6,27 @@ def set_attributes(self, params: List[object] = None) -> None: """ - An utility function used in classes to set attributes from the input list of parameters. + Utility function used in classes to set attributes from the input dictionary of parameters. + Args: - params (list): list of parameters. + params (dict): A dictionary of attribute names and values, typically `locals()`; the `self` key is skipped. """ - if params: - for k, v in params.items(): - if k != "self": - setattr(self, k, v) + if params: + for key, value in params.items(): + if key != "self": + setattr(self, key, value) - -def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): +def round_width(width: int, multiplier: float, min_width: int = 8, divisor: int = 8, ceil: bool = False) -> int: """ - Round width of filters based on width multiplier + Round the width of filters based on a width multiplier. + Args: - width (int): the channel dimensions of the input. - multiplier (float): the multiplication factor. - min_width (int): the minimum width after multiplication. - divisor (int): the new width should be dividable by divisor. - ceil (bool): If True, use ceiling as the rounding method. + width (int): The channel dimensions of the input. + multiplier (float): The multiplication factor.
+ min_width (int, optional): The minimum width after multiplication. + divisor (int, optional): The new width should be divisible by divisor. + ceil (bool, optional): If True, use ceiling as the rounding method. + + Returns: + int: The rounded width value. """ if not multiplier: return width @@ -39,10 +41,16 @@ def round_width(width, multiplier, min_width=8, divisor=8, ceil=False): width_out += divisor return int(width_out) - -def round_repeats(repeats, multiplier): +def round_repeats(repeats: int, multiplier: float) -> int: """ - Round number of layers based on depth multiplier. + Round the number of layers based on a depth multiplier. + + Args: + repeats (int): The original number of layers. + multiplier (float): The depth multiplier. + + Returns: + int: The rounded number of layers. """ if not multiplier: return repeats