
Commit a9cabd9: init upload (0 parents)

28 files changed: 2345 additions, 0 deletions

README.md

Lines changed: 55 additions & 0 deletions
# Deep Networks from the Principle of Rate Reduction

This repository is the official implementation of the paper [Deep Networks from the Principle of Rate Reduction](https://arxiv.org/abs/2010.14765) (2021) by [Kwan Ho Ryan Chan](https://ryanchankh.github.io)* (UC Berkeley), [Yaodong Yu](https://yaodongyu.github.io/)* (UC Berkeley), [Chong You](https://sites.google.com/view/cyou)* (UC Berkeley), [Haozhi Qi](https://haozhi.io/) (UC Berkeley), John Wright (Columbia), and Yi Ma (UC Berkeley).

## What is ReduNet?

ReduNet is a deep neural network constructed naturally by deriving the gradients of the Maximal Coding Rate Reduction (MCR<sup>2</sup>) [1] objective. Every layer of this network can be interpreted in terms of its mathematical operations, and the network as a whole is trained in a purely feed-forward manner. In addition, by imposing shift-invariance properties on our network, the convolutional operators can be derived using only the data and the MCR<sup>2</sup> objective, making our network design principled and interpretable.
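
For reference, the MCR<sup>2</sup> objective from [1] is, for a feature matrix Z ∈ R^{d×m} (m samples as columns), diagonal class-membership matrices Π<sub>j</sub>, and distortion ε:

```
\Delta R(Z, \Pi, \epsilon)
  = \underbrace{\frac{1}{2}\log\det\Big(I + \frac{d}{m\epsilon^2} Z Z^\top\Big)}_{R(Z,\epsilon)}
  - \underbrace{\sum_{j=1}^{k} \frac{\operatorname{tr}(\Pi_j)}{2m}\log\det\Big(I + \frac{d}{\operatorname{tr}(\Pi_j)\epsilon^2} Z \Pi_j Z^\top\Big)}_{R_c(Z,\epsilon \mid \Pi)}
```

Each ReduNet layer performs one gradient-ascent step on this objective, and the E and C<sup>j</sup> operators in the figure below come from the gradients of the expansion term R and the compression term R<sub>c</sub>, respectively.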
<p align="center">
  <img src="images/arch-redunet.jpg" width="350"><br>
  Figure: Weights and operations for one layer of ReduNet
</p>

[1] Yu, Yaodong, Kwan Ho Ryan Chan, Chong You, Chaobing Song, and Yi Ma. "[Learning diverse and discriminative representations via the principle of maximal coding rate reduction](https://proceedings.neurips.cc/paper/2020/file/6ad4174eba19ecb5fed17411a34ff5e6-Paper.pdf)." Advances in Neural Information Processing Systems 33 (2020).
## Requirements
This codebase is written for `python3`. To install the necessary Python packages, run `conda create --name redunet_official --file requirements.txt`.
## Core Usage and Design
The design of this repository aims to be easy to use and easy to integrate into your current experiment framework, as long as it uses PyTorch. The `ReduNet` object inherits from `nn.Sequential`, and layers (`ReduLayers`) such as `Vector`, `Fourier1D` and `Fourier2D` inherit from `nn.Module`. Loss functions are implemented in `loss.py`. Architecture and dataset options are located in `load.py`. Data objects and pre-set architectures live in the folders `datasets` and `architectures`; feel free to add more based on the experiments you want to run. We provide basic experiment setups in `train_<mode>.py` and `evaluate_<mode>.py`, where `<mode>` is the type of experiment. For utility functions, please check out `functional.py` and `utils.py`. Feel free to email us with any issues or suggestions.
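To make the intended composition concrete, here is a minimal sketch assuming only the interfaces visible in this repository (`ReduNet`, `Vector`, and the hyperparameters used in `architectures/mnist/flatten.py`); the real entry point for training is `train_forward.py`:

```
import torch
from redunet import ReduNet, Vector  # ReduNet subclasses nn.Sequential

# Five fully-connected ReduLayers, mirroring architectures/mnist/flatten.py.
net = ReduNet(
    *[Vector(eta=0.5, eps=0.1, lmbda=500, num_classes=10, dimensions=784)
      for _ in range(5)]
)

# Since ReduNet is an nn.Sequential, inference is a plain forward pass.
# (Layer parameters are set during forward construction in train_forward.py,
# so running a freshly built net is only a shape-level check.)
X = torch.randn(32, 784)  # a batch of flattened 28x28 images
Z = net(X)
```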
## Example: Forward Construction
To train a ReduNet using forward construction, please check out `train_forward.py`; for evaluation, please check out `evaluate_forward.py`. For example, to train a 50-layer ReduNet on MNIST using 1000 samples per class, run:
```
$ python3 train_forward.py --data mnistvector --arch layers50 --samples 1000
```
After training, you can evaluate the trained model with `evaluate_forward.py` by running:

```
$ python3 evaluate_forward.py --model_dir ./saved_models/forward/mnistvector+layers50/samples1000
```

This will evaluate using all available training and testing samples. For more training and testing options, please check out `train_forward.py` and `evaluate_forward.py`.
### Experiments in Paper
For the code used to generate the experimental results reported in our paper, please visit our other repository: [https://github.com/ryanchankh/redunet_paper](https://github.com/ryanchankh/redunet_paper)
## Reference
For technical details and full experimental results, please see the [paper](https://arxiv.org/abs/2010.14765). Please consider citing our work if you find it helpful:
```
@article{chan2020deep,
  title={Deep networks from the principle of rate reduction},
  author={Chan, Kwan Ho Ryan and Yu, Yaodong and You, Chong and Qi, Haozhi and Wright, John and Ma, Yi},
  journal={arXiv preprint arXiv:2010.14765},
  year={2020}
}
```

## License and Contributing

- This README is formatted based on [paperswithcode](https://github.com/paperswithcode/releasing-research-code).
- Feel free to post issues via GitHub.

## Contact
Please contact [email protected] and [email protected] if you have any questions about the code.

architectures/mnist/flatten.py

Lines changed: 14 additions & 0 deletions
```
from redunet import *


def flatten(layers, num_classes):
    # Stack `layers` fully-connected (Vector) ReduLayers into one ReduNet,
    # operating on flattened 28x28 MNIST images (784 dimensions).
    net = ReduNet(
        *[Vector(eta=0.5,
                 eps=0.1,
                 lmbda=500,
                 num_classes=num_classes,
                 dimensions=784,
                 ) for _ in range(layers)],
    )
    return net
```
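
Presumably the `layers50` architecture in the README's example maps onto this constructor through `load.py` (not shown in this commit); a hypothetical call would be:

```
net = flatten(layers=50, num_classes=10)
```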

architectures/mnist/lift2d.py

Lines changed: 15 additions & 0 deletions
```
from redunet import *


def lift2d(channels, layers, num_classes, seed=0):
    # Lift the single-channel input to `channels` feature maps, then stack
    # `layers` shift-invariant (Fourier2D) ReduLayers.
    net = ReduNet(
        Lift2D(1, channels, 9, seed=seed),  # 1 input channel -> `channels` maps; the 9 is presumably the filter size
        *[Fourier2D(eta=0.5,
                    eps=0.1,
                    lmbda=500,
                    num_classes=num_classes,
                    dimensions=(channels, 28, 28),
                    ) for _ in range(layers)],
    )
    return net
```
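
A hypothetical instantiation (these particular sizes are illustrative, not taken from this commit):

```
net = lift2d(channels=16, layers=5, num_classes=10)
```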

datasets/__init__.py

Whitespace-only changes.

datasets/mnist.py

Lines changed: 83 additions & 0 deletions
```
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from .utils_data import filter_class


# Each loader returns (trainset, testset, num_classes). The 5-class and
# 2-class variants keep only digits {0,...,4} and {0, 1} via filter_class;
# the mnistvector_* variants additionally flatten each 28x28 image into a
# 784-dimensional vector.

def mnist2d_10class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    num_classes = 10
    return trainset, testset, num_classes

def mnist2d_5class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4])
    testset, _ = filter_class(testset, [0, 1, 2, 3, 4])
    return trainset, testset, num_classes

def mnist2d_2class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    trainset, num_classes = filter_class(trainset, [0, 1])
    testset, _ = filter_class(testset, [0, 1])
    return trainset, testset, num_classes

def mnistvector_10class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.flatten())
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    num_classes = 10
    return trainset, testset, num_classes

def mnistvector_5class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.flatten())
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    trainset, num_classes = filter_class(trainset, [0, 1, 2, 3, 4])
    testset, _ = filter_class(testset, [0, 1, 2, 3, 4])
    return trainset, testset, num_classes

def mnistvector_2class(data_dir):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.flatten())
    ])
    trainset = datasets.MNIST(data_dir, train=True, transform=transform, download=True)
    testset = datasets.MNIST(data_dir, train=False, transform=transform, download=True)
    trainset, num_classes = filter_class(trainset, [0, 1])
    testset, _ = filter_class(testset, [0, 1])
    return trainset, testset, num_classes


if __name__ == '__main__':
    # Smoke test: load the filtered 2-class split in one full-size batch.
    trainset, testset, num_classes = mnist2d_2class('./data/')
    trainloader = DataLoader(trainset, batch_size=trainset.data.shape[0])
    print(trainset)
    print(testset)
    print(num_classes)

    batch_imgs, batch_lbls = next(iter(trainloader))
    print(batch_imgs.shape, batch_lbls.shape)
    print(batch_lbls.unique(return_counts=True))
```
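
The loaders can also be used directly from your own scripts; a quick sketch (the `./data/` path is illustrative, and `download=True` fetches MNIST on first use):

```
trainset, testset, num_classes = mnistvector_2class('./data/')
print(len(trainset), len(testset), num_classes)
# Expected: 12665 2115 2, since MNIST contains 5923 + 6742 training and
# 980 + 1135 test images for digits 0 and 1.
```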

datasets/utils_data.py

Lines changed: 24 additions & 0 deletions
```
import numpy as np
import torch


def filter_class(dataset, classes):
    # Keep only the samples whose label is in `classes`, mutating the
    # dataset's `data` and `targets` in place (samples are regrouped
    # class-by-class, so the original ordering is not preserved).
    data, labels = dataset.data, dataset.targets
    if type(labels) == list:
        labels = torch.tensor(labels)
    data_filter = []
    labels_filter = []
    for _class in classes:
        idx = labels == _class
        data_filter.append(data[idx])
        labels_filter.append(labels[idx])
    if type(dataset.data) == np.ndarray:
        dataset.data = np.vstack(data_filter)
        dataset.targets = np.hstack(labels_filter)
    elif type(dataset.data) == torch.Tensor:
        dataset.data = torch.cat(data_filter)
        dataset.targets = torch.cat(labels_filter)
    else:
        raise TypeError('dataset.data type neither np.ndarray nor torch.Tensor')
    return dataset, len(classes)
```
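
A minimal, self-contained check of `filter_class` using a stand-in object with tensor `data`/`targets` (hypothetical, just to illustrate the contract; any object with those two attributes works):

```
import torch
from types import SimpleNamespace

# Stand-in for a torchvision dataset: 6 samples with labels 0..2.
ds = SimpleNamespace(data=torch.arange(6).view(6, 1),
                     targets=torch.tensor([0, 1, 2, 0, 1, 2]))
ds, n = filter_class(ds, [0, 2])
print(n, ds.targets.tolist())  # 2 [0, 0, 2, 2]
```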

evaluate.py

Lines changed: 174 additions & 0 deletions
```
import numpy as np
import scipy.stats as sps
import torch

from sklearn.svm import LinearSVC, SVC
from sklearn.decomposition import PCA
from sklearn.decomposition import TruncatedSVD
from sklearn.linear_model import SGDClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

import functional as F
import utils


def evaluate(eval_dir, method, train_features, train_labels, test_features, test_labels, **kwargs):
    # Dispatch to the requested classifier and save the accuracies to eval_dir.
    if method == 'svm':
        acc_train, acc_test = svm(train_features, train_labels, test_features, test_labels)
    elif method == 'knn':
        acc_train, acc_test = knn(train_features, train_labels, test_features, test_labels, **kwargs)
    elif method == 'nearsub':
        acc_train, acc_test = nearsub(train_features, train_labels, test_features, test_labels, **kwargs)
    elif method == 'nearsub_pca':
        # nearsub_pca scores the test set only, so no train accuracy is reported.
        acc_train = None
        acc_test = nearsub_pca(train_features, train_labels, test_features, test_labels, **kwargs)
    acc_dict = {'train': acc_train, 'test': acc_test}
    utils.save_params(eval_dir, acc_dict, name=f'acc_{method}')

def svm(train_features, train_labels, test_features, test_labels):
    svm = LinearSVC(verbose=0, random_state=10)
    svm.fit(train_features, train_labels)
    acc_train = svm.score(train_features, train_labels)
    acc_test = svm.score(test_features, test_labels)
    print("SVM: {}, {}".format(acc_train, acc_test))
    return acc_train, acc_test

# An earlier numpy-based knn, kept commented out in the original:
# def knn(train_features, train_labels, test_features, test_labels, k=5):
#     sim_mat = train_features @ train_features.T
#     topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
#     topk_pred = train_labels[topk.indices]
#     test_pred = torch.tensor(topk_pred).mode(0).values.detach()
#     acc_train = compute_accuracy(test_pred.numpy(), train_labels)

#     sim_mat = train_features @ test_features.T
#     topk = torch.from_numpy(sim_mat).topk(k=k, dim=0)
#     topk_pred = train_labels[topk.indices]
#     test_pred = torch.tensor(topk_pred).mode(0).values.detach()
#     acc_test = compute_accuracy(test_pred.numpy(), test_labels)
#     print("kNN: {}, {}".format(acc_train, acc_test))
#     return acc_train, acc_test

def knn(train_features, train_labels, test_features, test_labels, k=5):
    # k-nearest-neighbor classification with inner-product similarity;
    # expects torch tensors. Each sample's prediction is the mode of the
    # labels of its k most similar training samples.
    sim_mat = train_features @ train_features.T
    topk = sim_mat.topk(k=k, dim=0)
    topk_pred = train_labels[topk.indices]
    test_pred = topk_pred.mode(0).values.detach()
    acc_train = compute_accuracy(test_pred, train_labels)

    sim_mat = train_features @ test_features.T
    topk = sim_mat.topk(k=k, dim=0)
    topk_pred = train_labels[topk.indices]
    test_pred = topk_pred.mode(0).values.detach()
    acc_test = compute_accuracy(test_pred, test_labels)
    print("kNN: {}, {}".format(acc_train, acc_test))
    return acc_train, acc_test

# TODO: 1. implement pytorch version 2. support batches
# (an earlier numpy/TruncatedSVD nearsub, kept commented out in the original):
# def nearsub(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10, return_pred=False):
#     train_scores, test_scores = [], []
#     classes = np.arange(num_classes)
#     features_sort, _ = utils.sort_dataset(train_features, train_labels,
#                                           classes=classes, stack=False)
#     fd = features_sort[0].shape[1]
#     if n_comp >= fd:
#         n_comp = fd - 1
#     for j in classes:
#         svd = TruncatedSVD(n_components=n_comp).fit(features_sort[j])
#         subspace_j = np.eye(fd) - svd.components_.T @ svd.components_
#         train_j = subspace_j @ train_features.T
#         test_j = subspace_j @ test_features.T
#         train_scores_j = np.linalg.norm(train_j, ord=2, axis=0)
#         test_scores_j = np.linalg.norm(test_j, ord=2, axis=0)
#         train_scores.append(train_scores_j)
#         test_scores.append(test_scores_j)
#     train_pred = np.argmin(train_scores, axis=0)
#     test_pred = np.argmin(test_scores, axis=0)
#     if return_pred:
#         return train_pred.tolist(), test_pred.tolist()
#     train_acc = compute_accuracy(classes[train_pred], train_labels)
#     test_acc = compute_accuracy(classes[test_pred], test_labels)
#     print('SVD: {}, {}'.format(train_acc, test_acc))
#     return train_acc, test_acc

def nearsub(train_features, train_labels, test_features, test_labels,
            num_classes, n_comp=10, return_pred=False):
    # Nearest-subspace classification: fit an n_comp-dimensional subspace to
    # each class and assign each sample to the class with the smallest
    # residual after projecting out that subspace.
    train_scores, test_scores = [], []
    classes = np.arange(num_classes)
    features_sort, _ = utils.sort_dataset(train_features, train_labels,
                                          classes=classes, stack=False)
    fd = features_sort[0].shape[1]
    for j in classes:
        _, _, V = torch.svd(features_sort[j])
        components = V[:, :n_comp].T
        subspace_j = torch.eye(fd) - components.T @ components
        train_j = subspace_j @ train_features.T
        test_j = subspace_j @ test_features.T
        train_scores_j = torch.linalg.norm(train_j, ord=2, dim=0)  # vector 2-norm of each column (residual per sample)
        test_scores_j = torch.linalg.norm(test_j, ord=2, dim=0)
        train_scores.append(train_scores_j)
        test_scores.append(test_scores_j)
    train_pred = torch.stack(train_scores).argmin(0)
    test_pred = torch.stack(test_scores).argmin(0)
    if return_pred:
        return train_pred.numpy(), test_pred.numpy()
    train_acc = compute_accuracy(classes[train_pred.numpy()], train_labels.numpy())
    test_acc = compute_accuracy(classes[test_pred.numpy()], test_labels.numpy())
    print('SVD: {}, {}'.format(train_acc, test_acc))
    return train_acc, test_acc

def nearsub_pca(train_features, train_labels, test_features, test_labels, num_classes, n_comp=10):
    # Same idea as nearsub, but numpy/PCA-based, mean-centered, and scoring
    # the test set only.
    scores_pca = []
    classes = np.arange(num_classes)
    features_sort, _ = utils.sort_dataset(train_features, train_labels, classes=classes, stack=False)
    fd = features_sort[0].shape[1]
    if n_comp >= fd:
        n_comp = fd - 1
    for j in np.arange(len(classes)):
        pca = PCA(n_components=n_comp).fit(features_sort[j])
        pca_subspace = pca.components_.T
        mean = np.mean(features_sort[j], axis=0)
        pca_j = (np.eye(fd) - pca_subspace @ pca_subspace.T) \
                @ (test_features - mean).T
        score_pca_j = np.linalg.norm(pca_j, ord=2, axis=0)
        scores_pca.append(score_pca_j)
    test_predict_pca = np.argmin(scores_pca, axis=0)
    acc_pca = compute_accuracy(classes[test_predict_pca], test_labels)
    print('PCA: {}'.format(acc_pca))
    return acc_pca

def argmax(train_features, train_labels, test_features, test_labels):
    train_pred = train_features.argmax(1)
    train_acc = compute_accuracy(train_pred, train_labels)
    test_pred = test_features.argmax(1)
    test_acc = compute_accuracy(test_pred, test_labels)
    return train_acc, test_acc

def compute_accuracy(y_pred, y_true):
    """Compute accuracy by counting correct classifications."""
    assert y_pred.shape == y_true.shape
    if type(y_pred) == torch.Tensor:
        n_wrong = torch.count_nonzero(y_pred - y_true).item()
    elif type(y_pred) == np.ndarray:
        n_wrong = np.count_nonzero(y_pred - y_true)
    else:
        raise TypeError("Not Tensor nor Array type.")
    n_samples = len(y_pred)
    return 1 - n_wrong / n_samples

def baseline(train_features, train_labels, test_features, test_labels):
    # Scikit-learn baselines, fit on train features and scored on test features.
    # Note: loss='log' was renamed to 'log_loss' in newer scikit-learn releases.
    test_models = {'log_l2': SGDClassifier(loss='log', max_iter=10000, random_state=42),
                   'SVM_linear': LinearSVC(max_iter=10000, random_state=42),
                   'SVM_RBF': SVC(kernel='rbf', random_state=42),
                   'DecisionTree': DecisionTreeClassifier(),
                   'RandomForest': RandomForestClassifier()}
    for model_name in test_models:
        test_model = test_models[model_name]
        test_model.fit(train_features, train_labels)
        score = test_model.score(test_features, test_labels)
        print(f"{model_name}: {score}")

def majority_vote(pred, true):
    pred_majority = sps.mode(pred, axis=0)[0].squeeze()
    return compute_accuracy(pred_majority, true)
```
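
A quick sanity check of the `knn` routine above on random features (shapes and values are illustrative; note that this version assumes torch tensors, since it calls `.topk()` and `.mode()` directly):

```
import torch

torch.manual_seed(0)
train_Z = torch.randn(1000, 128)   # stand-in for ReduNet features
train_y = torch.randint(0, 10, (1000,))
test_Z = torch.randn(200, 128)
test_y = torch.randint(0, 10, (200,))

acc_train, acc_test = knn(train_Z, train_y, test_Z, test_y, k=5)
# On random features both accuracies should hover near chance (about 0.1);
# the train-side score is mildly inflated because each sample can match itself.
```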
