diff --git a/simlarity/feature_extraction.py b/simlarity/feature_extraction.py index 547972e..c87aed2 100644 --- a/simlarity/feature_extraction.py +++ b/simlarity/feature_extraction.py @@ -179,11 +179,7 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT def _get_leaf_modules_for_ops() -> List[type]: members = inspect.getmembers(torchvision.ops) - result = [] - for _, obj in members: - if inspect.isclass(obj) and issubclass(obj, torch.nn.Module): - result.append(obj) - return result + return [obj for _, obj in members if inspect.isclass(obj) and issubclass(obj, torch.nn.Module)] def get_graph_node_names( @@ -496,10 +492,7 @@ def to_strdict(n) -> Dict[str, str]: ) # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) + orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"] assert len(orig_output_nodes) for n in orig_output_nodes: graph_module.graph.erase_node(n) @@ -654,19 +647,13 @@ def to_strdict(n) -> Dict[str, str]: ) # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) + orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"] assert len(orig_output_nodes) for n in orig_output_nodes: graph_module.graph.erase_node(n) # Remove existing input nodes (train mode) - orig_input_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "placeholder": - orig_input_nodes.append(n) + orig_input_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "placeholder"] assert len(orig_input_nodes) # for n in orig_input_nodes: # n.users=()