Clean up pyre issues in tests/helpers/influence/common.py (#1455)
Summary:
Pull Request resolved: #1455

Fix some pyre issues in pytorch/captum/tests/helpers/influence/common.py

Reviewed By: cyrjano

Differential Revision: D66902046

fbshipit-source-id: 8b74f638a2060330e8665ff6103a6f255e1a205e
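
For context, the pattern applied throughout the diff below: pyre-fixme[3] suppresses a "Return type must be annotated" error and pyre-fixme[2] a "Parameter must be annotated" error, so adding the missing annotation lets the suppression comment be deleted. A minimal before/after sketch (the function names are illustrative, not from this file):

# Before: the suppression comment hides the missing return annotation.
# pyre-fixme[3]: Return type must be annotated.
def is_even(n: int):
    return n % 2 == 0

# After: with the return type annotated, the pyre-fixme can be removed.
def is_even_annotated(n: int) -> bool:
    return n % 2 == 0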
jsawruk authored and facebook-github-bot committed Dec 12, 2024
1 parent 2a2b41d commit 2b9f4ae
Showing 1 changed file with 18 additions and 31 deletions.
49 changes: 18 additions & 31 deletions tests/helpers/influence/common.py
@@ -23,18 +23,16 @@
 from torch.utils.data import DataLoader, Dataset
 
 
-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def _isSorted(x, key=lambda x: x, descending=True):
+def _isSorted(x, key=lambda x: x, descending=True) -> bool:
     if descending:
-        return all([key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1)])
+        return all(key(x[i]) >= key(x[i + 1]) for i in range(len(x) - 1))
     else:
-        return all([key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1)])
+        return all(key(x[i]) <= key(x[i + 1]) for i in range(len(x) - 1))
 
 
-# pyre-fixme[3]: Return type must be annotated.
 # pyre-fixme[2]: Parameter must be annotated.
-def _wrap_model_in_dataparallel(net):
+def _wrap_model_in_dataparallel(net) -> Module:
     alt_device_ids = [0] + [x for x in range(torch.cuda.device_count() - 1, 0, -1)]
     net = net.cuda()
     return torch.nn.DataParallel(net, device_ids=alt_device_ids)
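
Besides deleting the pyre-fixme[3] comments, this hunk also replaces all([...]) with a bare generator expression: all() then consumes the comparisons lazily and stops at the first False, instead of materializing the whole list first. A small sketch of the behavior (the helper name is illustrative, not from the commit):

def is_sorted_descending(xs: list) -> bool:
    # The generator lets all() short-circuit on the first out-of-order pair.
    return all(xs[i] >= xs[i + 1] for i in range(len(xs) - 1))

assert is_sorted_descending([3, 2, 1])
assert not is_sorted_descending([1, 3, 2])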
@@ -60,9 +58,7 @@ def __init__(
     def __len__(self) -> int:
         return len(self.samples)
 
-    # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int) -> Tuple[Tensor, Tensor]:
         return (self.samples[idx], self.labels[idx])

@@ -83,8 +79,7 @@ def __len__(self) -> int:
         return len(self.samples[0])
 
     # pyre-fixme[3]: Return type must be annotated.
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __getitem__(self, idx):
+    def __getitem__(self, idx: int):
         """
         The signature of the returning item is: List[List], where the contents
         are: [sample_0, sample_1, ...] + [labels] (two lists concacenated).
@@ -98,10 +93,8 @@ def __init__(
         num_features: int,
         use_gpu: bool = False,
     ) -> None:
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.samples = torch.diag(torch.ones(num_features))
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.labels = torch.zeros(num_features).unsqueeze(1)
+        self.samples: Tensor = torch.diag(torch.ones(num_features))
+        self.labels: Tensor = torch.zeros(num_features).unsqueeze(1)
         if use_gpu:
             self.samples = self.samples.cuda()
             self.labels = self.labels.cuda()
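
The pyre-fixme[4] ("Attribute must be annotated") suppressions are cleared the same way: by annotating the attribute at its first assignment. A standalone sketch, with a hypothetical class name:

import torch
from torch import Tensor

class TinyDataset:
    def __init__(self, num_features: int) -> None:
        # Inline annotations at first assignment satisfy pyre, replacing
        # the "# pyre-fixme[4]: Attribute must be annotated." comments.
        self.samples: Tensor = torch.diag(torch.ones(num_features))
        self.labels: Tensor = torch.zeros(num_features).unsqueeze(1)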
@@ -115,23 +108,22 @@ def __init__(
         num_features: int,
         use_gpu: bool = False,
     ) -> None:
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.samples = (
+        self.samples: Tensor = (
             torch.arange(start=low, end=high, dtype=torch.float)
             .repeat(num_features, 1)
             .transpose(1, 0)
         )
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.labels = torch.arange(start=low, end=high, dtype=torch.float).unsqueeze(1)
+        self.labels: Tensor = torch.arange(
+            start=low, end=high, dtype=torch.float
+        ).unsqueeze(1)
         if use_gpu:
             self.samples = self.samples.cuda()
             self.labels = self.labels.cuda()
 
 
 class BinaryDataset(ExplicitDataset):
     def __init__(self, use_gpu: bool = False) -> None:
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.samples = F.normalize(
+        self.samples: Tensor = F.normalize(
             torch.stack(
                 (
                     torch.Tensor([1, 1]),
@@ -161,8 +153,7 @@ def __init__(self, use_gpu: bool = False) -> None:
                 )
             )
         )
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.labels = torch.cat(
+        self.labels: Tensor = torch.cat(
             (
                 torch.Tensor([1]).repeat(12, 1),
                 torch.Tensor([-1]).repeat(12, 1),
@@ -350,13 +341,10 @@ def get_random_model_and_data(
     tmpdir,
     # pyre-fixme[2]: Parameter must be annotated.
     unpack_inputs,
-    # pyre-fixme[2]: Parameter must be annotated.
-    return_test_data=True,
+    return_test_data: bool = True,
     gpu_setting: Optional[str] = None,
-    # pyre-fixme[2]: Parameter must be annotated.
-    return_hessian_data=False,
-    # pyre-fixme[2]: Parameter must be annotated.
-    model_type="random",
+    return_hessian_data: bool = False,
+    model_type: str = "random",
 ):
     """
     returns a model, training data, and optionally data for computing the hessian
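
Note the pattern in this hunk: only the parameters with defaults gain inline annotations (their types are evident from the default values), while tmpdir and unpack_inputs keep their pyre-fixme[2] suppressions. A hypothetical signature showing the same shape:

def load_fixture(tmpdir, strict: bool = False, model_type: str = "random") -> None:
    # tmpdir stays unannotated here, mirroring the suppressed parameters above.
    ...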
@@ -534,10 +522,9 @@ def generate_symmetric_matrix_given_eigenvalues(
     return torch.matmul(Q, torch.matmul(torch.diag(torch.tensor(eigenvalues)), Q.T))
 
 
-# pyre-fixme[3]: Return type must be annotated.
 def generate_assymetric_matrix_given_eigenvalues(
     eigenvalues: Union[Tensor, List[float]]
-):
+) -> Tensor:
     """
     following https://github.com/google-research/jax-influence/blob/74bd321156b5445bb35b9594568e4eaaec1a76a3/jax_influence/test_utils.py#L105 # noqa: E501
     generate assymetric random matrix with specified eigenvalues. this is used in