-
Notifications
You must be signed in to change notification settings - Fork 0
/
tensor_utils.py
107 lines (80 loc) · 2.92 KB
/
tensor_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
def difference(t1, t2):
    """Return the symmetric difference of two tensors.

    Both inputs are reduced to their unique values (flattened), and every value
    that occurs in exactly one of them is returned as a sorted 1-d tensor.
    """
    pooled = torch.cat((t1.unique(), t2.unique()))
    values, occurrences = pooled.unique(return_counts=True)
    # A value seen only once in the pooled uniques belongs to just one input.
    return values[occurrences == 1]
def intersection(t1, t2):
    """Return the values present in both tensors.

    Both inputs are reduced to their unique values (flattened), and the values
    shared by the two are returned as a sorted 1-d tensor.
    """
    pooled = torch.cat((t1.unique(), t2.unique()))
    values, occurrences = pooled.unique(return_counts=True)
    # Each side is already de-duplicated, so a repeat means "in both".
    return values[occurrences > 1]
def setdiff(t1, t2):
    """Return the values of t1 that do not occur in t2.

    Computed as the symmetric difference of t1 and t2 restricted to the
    values of t1; the result is a sorted 1-d tensor of unique values.
    """
    return intersection(difference(t1, t2), t1)
def concatenate_1d(*tensors):
    """Concatenate one-dimensional tensors end to end.

    Raises:
        ValueError: if any argument is not exactly one-dimensional.
    """
    if any(tensor.dim() != 1 for tensor in tensors):
        raise ValueError("Can only concatenate 1d tensors. Otherwise, use rbind / cbind.")
    return torch.cat(tensors, dim=0)
def cbind(*tensors):
    """Combine tensors side by side as columns (concatenation along dim 1).

    One-dimensional inputs are first turned into column vectors of shape
    (n, 1), so plain vectors can be mixed with 2d matrices that have the same
    number of rows. All other inputs are passed to ``torch.cat`` unchanged.
    """
    columns = [tensor.unsqueeze(dim=1) if tensor.dim() == 1 else tensor
               for tensor in tensors]
    return torch.cat(columns, 1)
def rbind(*tensors):
    """Stack tensors on top of each other as rows (concatenation along dim 0).

    Raises:
        ValueError: if any argument has fewer than two dimensions; use
            ``concatenate_1d`` for plain vectors.
    """
    if not all(tensor.dim() >= 2 for tensor in tensors):
        raise ValueError("rbind only takes two-dimensional tensors as input")
    return torch.cat(tensors, dim=0)
def unsqueeze_to_1d(*tensors):
    """Turn zero-dimensional (scalar) tensors into one-element 1d tensors.

    Tensors that already have one or more dimensions are returned unchanged.
    With a single argument the (possibly unsqueezed) tensor itself is
    returned; with several arguments a list is returned in the same order.
    """
    def _lift(tensor):
        # Only scalars gain a leading axis; everything else passes through.
        return tensor.unsqueeze(dim=0) if tensor.dim() == 0 else tensor

    if len(tensors) > 1:
        return [_lift(tensor) for tensor in tensors]
    return _lift(tensors[0])
def unsqueeze_to_2d(*tensors):
    """Turn one-dimensional tensors into two-dimensional column vectors.

    A 1d tensor of length n becomes shape (n, 1); tensors of any other rank
    are returned unchanged. With a single argument the tensor itself is
    returned; with several arguments a list is returned in the same order.
    """
    def _as_column(tensor):
        return tensor.unsqueeze(dim=1) if tensor.dim() == 1 else tensor

    if len(tensors) > 1:
        return [_as_column(tensor) for tensor in tensors]
    return _as_column(tensors[0])
def convert_type(torch_type, *tensors):
    """Convert every given tensor with ``Tensor.to(torch_type)``.

    Returns the single converted tensor when one tensor is given, otherwise
    a list of converted tensors in the original order.
    """
    converted = [tensor.to(torch_type) for tensor in tensors]
    return converted if len(tensors) > 1 else converted[0]
def shuffle(t, mode=None):
    """Return a randomly permuted copy of the given tensor.

    Args:
        t: Tensor to shuffle. It is not modified; indexing with a permutation
            produces a new tensor.
        mode: What to permute.
            - "within_columns": permutes whole rows (along dim 0), so every
              column is reordered by the same random permutation.
            - "within_rows": permutes whole columns (along dim 1), so every
              row is reordered by the same random permutation. Needs at least
              two dimensions.
            - any other value (default ``None``): permutes all elements
              regardless of position, keeping the original shape. Note that
              unrecognised mode strings also fall through to this case.

    Returns:
        A tensor with the same shape as ``t``.
    """
    # NOTE(review): torch.randperm builds the permutation on the default (CPU)
    # device; passing device=t.device would avoid a host->device copy for GPU
    # tensors but changes which RNG is consumed, so it is left as is.
    if mode == "within_columns":
        return t[torch.randperm(len(t))]
    if mode == "within_rows":
        # Same permutation (same RNG draw) as transposing, permuting dim 0 and
        # transposing back, without the two transposes.
        return t[:, torch.randperm(t.size(1))]
    # reshape (unlike view) also works for non-contiguous inputs such as
    # transposed or sliced tensors; it only copies when it has to.
    flat = t.reshape(-1)
    return flat[torch.randperm(flat.numel())].reshape(t.shape)