Skip to content

Commit 6eee8f7

Browse files
denklegithub-actions[bot]mikeheddes
authored
Add UCI classification benchmark and example using intRVFL (#98)
* Draft first version of intRVFL classification * [github-action] formatting fixes * Reworked code targeting modularity but still there is a number of issues to discuss * [github-action] formatting fixes * Revised code to improve modularity * [github-action] formatting fixes * [github-action] formatting fixes * Further improve the code logic and allocation across the library * [github-action] formatting fixes * Allow more types of models in Density encoding * Refactor benchmark * Refactor benchmark example * [github-action] formatting fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: mikeheddes <mikeheddes@gmail.com>
1 parent 279e2ef commit 6eee8f7

File tree

7 files changed

+626
-4
lines changed

7 files changed

+626
-4
lines changed

docs/datasets.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ Base classes
148148
:toctree: generated/
149149
:template: class_dataset.rst
150150

151+
UCIClassificationBenchmark
151152
CollectionDataset
152153
DatasetFourFold
153154
DatasetTrainTest

docs/embeddings.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@ torchhd.embeddings
1616
Thermometer
1717
Circular
1818
Projection
19-
Sinusoid
19+
Sinusoid
20+
Density

examples/UCI_benchmark_intRVFL.py

Lines changed: 326 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,326 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.utils.data as data
4+
from torch import Tensor
5+
from tqdm import tqdm
6+
7+
# Note: this example requires the torchmetrics library: https://torchmetrics.readthedocs.io
8+
import torchmetrics
9+
import torchhd
10+
from torchhd.datasets import UCIClassificationBenchmark
11+
12+
13+
# Function for performing min-max normalization of the input data samples
14+
def create_min_max_normalize(min: Tensor, max: Tensor):
    """Build a transform that rescales samples with min-max normalization.

    Args:
        min (Tensor): Value subtracted from every input (the lower bound).
        max (Tensor): Upper bound; ``max - min`` is the divisor.

    Returns:
        A callable mapping an input tensor to ``(input - min) / (max - min)``
        passed through ``torch.nan_to_num``. The ``0 / 0`` case (a constant
        feature evaluated at its constant value) therefore yields ``0``, and
        infinite results are replaced by the largest finite values of the dtype.
    """

    def normalize(input: Tensor) -> Tensor:
        shifted = input - min
        value_range = max - min
        scaled = shifted / value_range
        # Replace NaN (and +/-inf) produced by a zero value range.
        return torch.nan_to_num(scaled)

    return normalize
19+
20+
21+
# Function that forms the classifier (readout matrix) with the ridge regression
22+
def classifier_ridge_regression(
    train_ld: data.DataLoader,
    dimensions: int,
    num_classes: int,
    lamb: float,
    encoding_function,
    data_type: torch.dtype,
    device: torch.device,
):
    """Form the classifier (readout matrix) with ridge regression.

    Every batch from ``train_ld`` is encoded with ``encoding_function`` and
    stacked into a matrix ``H``; the labels are stacked as a one-hot matrix
    ``Y``. The readout matrix is then computed as
    ``Y^T H pinv(H^T H + lamb * diag(var(H)))``.

    Args:
        train_ld (data.DataLoader): Loader yielding ``(samples, labels)`` batches.
            Labels are used as column indices of the one-hot matrix.
        dimensions (int): Number of columns produced by ``encoding_function``.
        num_classes (int): Number of classes (columns of the one-hot matrix).
        lamb (float): Regularization strength of the ridge regression.
        encoding_function: Callable mapping a batch of samples to a
            ``(batch, dimensions)`` tensor.
        data_type (torch.dtype): dtype of the accumulation buffers.
        device (torch.device): Device holding the buffers; batches are moved to it.

    Returns:
        Tensor of shape ``(num_classes, dimensions)`` with the readout weights.
    """
    # Upper bound on the number of rows the loader can deliver.
    num_train = len(train_ld.dataset)
    # Collects high-dimensional representations of the training samples.
    total_samples_hv = torch.zeros(
        num_train,
        dimensions,
        dtype=data_type,
        device=device,
    )
    # Collects one-hot encodings of the class labels.
    labels_one_hot = torch.zeros(
        num_train,
        num_classes,
        dtype=data_type,
        device=device,
    )

    with torch.no_grad():
        count = 0
        for samples, labels in tqdm(train_ld, desc="Training"):
            samples = samples.to(device)
            labels = labels.to(device)
            end = count + samples.size(0)

            # Make one-hot encoding: set column `label` in each row of this batch.
            rows = torch.arange(count, end, device=labels.device)
            labels_one_hot[rows, labels] = 1

            # Make transformation into high-dimensional space.
            total_samples_hv[count:end, :] = encoding_function(samples)

            count = end

    # The loader may yield fewer samples than len(dataset) (e.g. drop_last or a
    # subsampling sampler). Drop rows that were never written so the all-zero
    # padding does not distort the variance used by the regularizer.
    total_samples_hv = total_samples_hv[:count]
    labels_one_hot = labels_one_hot[:count]

    # Compute the readout matrix using the ridge regression.
    gram = torch.t(total_samples_hv) @ total_samples_hv
    regularizer = lamb * torch.diag(torch.var(total_samples_hv, 0))
    Wout = torch.t(labels_one_hot) @ total_samples_hv @ torch.linalg.pinv(gram + regularizer)

    return Wout
75+
76+
77+
# Specify a model to be evaluated
78+
class IntRVFLRidge(nn.Module):
    """Class implementing integer random vector functional link network (intRVFL) model as described in `Density Encoding Enables Resource-Efficient Randomly Connected Neural Networks <https://doi.org/10.1109/TNNLS.2020.3015971>`_.

    Args:
        dataset (torchhd.datasets.CollectionDataset): Specifies a dataset to be evaluated by the model. Its ``name`` selects the hyperparameters and its ``classes`` sets the number of outputs.
        num_feat (int): Number of features in the dataset.
        device (torch.device, optional): Specifies device to be used for Torch. When ``None``, ``fit`` uses the device of the readout weights.

    Raises:
        KeyError: If ``dataset.name`` has no entry in ``INT_RVFL_HYPER``.
    """

    # These values of hyperparameters were found via the grid search for intRVFL model as described in the article.
    # Each entry is (dimensions, lamb, kappa).
    INT_RVFL_HYPER = {
        "abalone": (1450, 32, 15),
        "acute-inflammation": (50, 0.0009765625, 1),
        "acute-nephritis": (50, 0.0009765625, 1),
        "adult": (1150, 0.0625, 3),
        "annealing": (1150, 0.015625, 7),
        "arrhythmia": (1400, 0.0009765625, 7),
        "audiology-std": (950, 16, 3),
        "balance-scale": (50, 32, 7),
        "balloons": (50, 0.0009765625, 1),
        "bank": (200, 0.001953125, 7),
        "blood": (50, 16, 7),
        "breast-cancer": (50, 32, 1),
        "breast-cancer-wisc": (650, 16, 3),
        "breast-cancer-wisc-diag": (1500, 2, 3),
        "breast-cancer-wisc-prog": (1450, 0.01562500, 3),
        "breast-tissue": (1300, 0.1250000, 1),
        "car": (250, 32, 3),
        "cardiotocography-10clases": (1350, 0.0009765625, 3),
        "cardiotocography-3clases": (900, 0.007812500, 15),
        "chess-krvk": (800, 4, 1),
        "chess-krvkp": (1350, 0.01562500, 3),
        "congressional-voting": (100, 32, 15),
        "conn-bench-sonar-mines-rocks": (1100, 0.01562500, 3),
        "conn-bench-vowel-deterding": (1350, 8, 3),
        "connect-4": (1100, 0.5, 3),
        "contrac": (50, 8, 7),
        "credit-approval": (200, 32, 7),
        "cylinder-bands": (1100, 0.0009765625, 7),
        "dermatology": (900, 8, 3),
        "echocardiogram": (250, 32, 15),
        "ecoli": (350, 32, 3),
        "energy-y1": (650, 0.1250000, 3),
        "energy-y2": (1000, 0.0625, 7),
        "fertility": (150, 32, 7),
        "flags": (900, 32, 15),
        "glass": (1400, 0.03125000, 3),
        "haberman-survival": (100, 32, 3),
        "hayes-roth": (50, 16, 1),
        "heart-cleveland": (50, 32, 15),
        "heart-hungarian": (50, 16, 15),
        "heart-switzerland": (50, 8, 15),
        "heart-va": (1350, 0.1250000, 15),
        "hepatitis": (1300, 0.03125000, 1),
        "hill-valley": (150, 0.01562500, 1),
        "horse-colic": (850, 32, 1),
        "ilpd-indian-liver": (1200, 0.25, 7),
        "image-segmentation": (650, 8, 1),
        "ionosphere": (1150, 0.001953125, 1),
        "iris": (50, 4, 3),
        "led-display": (50, 0.0009765625, 7),
        "lenses": (50, 0.03125000, 1),
        "letter": (1500, 32, 1),
        "libras": (1250, 0.1250000, 3),
        "low-res-spect": (1400, 8, 7),
        "lung-cancer": (450, 0.0009765625, 1),
        "lymphography": (1150, 32, 1),
        "magic": (800, 16, 3),
        "mammographic": (150, 16, 7),
        "miniboone": (650, 0.0625, 15),
        "molec-biol-promoter": (1250, 32, 1),
        "molec-biol-splice": (1000, 8, 15),
        "monks-1": (50, 4, 3),
        "monks-2": (400, 32, 1),
        "monks-3": (50, 4, 15),
        "mushroom": (150, 0.25, 3),
        "musk-1": (1300, 0.001953125, 7),
        "musk-2": (1150, 0.007812500, 7),
        "nursery": (1000, 32, 3),
        "oocytes_merluccius_nucleus_4d": (1500, 1, 7),
        "oocytes_merluccius_states_2f": (1500, 0.0625, 7),
        "oocytes_trisopterus_nucleus_2f": (1450, 0.003906250, 3),
        "oocytes_trisopterus_states_5b": (1450, 2, 7),
        "optical": (1100, 32, 7),
        "ozone": (50, 0.003906250, 1),
        "page-blocks": (800, 0.001953125, 1),
        "parkinsons": (1200, 0.5, 1),
        "pendigits": (1500, 0.1250000, 1),
        "pima": (50, 32, 1),
        "pittsburg-bridges-MATERIAL": (100, 8, 1),
        "pittsburg-bridges-REL-L": (1200, 0.5, 1),
        "pittsburg-bridges-SPAN": (450, 4, 7),
        "pittsburg-bridges-T-OR-D": (1000, 16, 1),
        "pittsburg-bridges-TYPE": (50, 32, 7),
        "planning": (50, 32, 1),
        "plant-margin": (1350, 2, 7),
        "plant-shape": (1450, 0.25, 3),
        "plant-texture": (1500, 4, 7),
        "post-operative": (50, 32, 15),
        "primary-tumor": (950, 32, 3),
        "ringnorm": (1500, 0.125, 3),
        "seeds": (550, 32, 1),
        "semeion": (1400, 32, 15),
        "soybean": (850, 1, 3),
        "spambase": (1350, 0.0078125, 15),
        "spect": (50, 32, 1),
        "spectf": (1100, 0.25, 15),
        "statlog-australian-credit": (200, 32, 15),
        "statlog-german-credit": (500, 32, 15),
        "statlog-heart": (50, 32, 7),
        "statlog-image": (950, 0.125, 1),
        "statlog-landsat": (1500, 16, 3),
        "statlog-shuttle": (100, 0.125, 7),
        "statlog-vehicle": (1450, 0.125, 7),
        "steel-plates": (1500, 0.0078125, 3),
        "synthetic-control": (1350, 16, 3),
        "teaching": (400, 32, 3),
        "thyroid": (300, 0.001953125, 7),
        "tic-tac-toe": (750, 8, 1),
        "titanic": (50, 0.0009765625, 1),
        "trains": (100, 16, 1),
        "twonorm": (1100, 0.0078125, 15),
        "vertebral-column-2clases": (250, 32, 3),
        "vertebral-column-3clases": (200, 32, 15),
        "wall-following": (1200, 0.00390625, 3),
        "waveform": (1400, 8, 7),
        "waveform-noise": (1300, 0.0009765625, 15),
        "wine": (850, 32, 1),
        "wine-quality-red": (1100, 32, 1),
        "wine-quality-white": (950, 8, 3),
        "yeast": (1350, 4, 1),
        "zoo": (400, 8, 7),
    }

    def __init__(
        self,
        dataset: torchhd.datasets.CollectionDataset,
        num_feat: int,
        device: torch.device = None,
    ):
        super().__init__()
        self.device = device
        self.num_feat = num_feat
        # Fetch the hyperparameters for the corresponding dataset; fail with a
        # message that names the dataset instead of a bare key.
        try:
            dimensions, lamb, kappa = self.INT_RVFL_HYPER[dataset.name]
        except KeyError:
            raise KeyError(
                f"No intRVFL hyperparameters available for dataset {dataset.name!r}"
            ) from None
        # Dimensionality of vectors used when transforming input data
        self.dimensions = dimensions
        # Regularization parameter used for ridge regression classifier
        self.lamb = lamb
        # Parameter of the clipping function used as the part of transforming input data
        self.kappa = kappa
        # Number of classes in the dataset
        self.num_classes = len(dataset.classes)
        # Initialize the classifier with an all-zero readout; `fit` overwrites it.
        self.classify = nn.Linear(self.dimensions, self.num_classes, bias=False)
        nn.init.zeros_(self.classify.weight)
        # Set up the encoding for the model as specified in "Density"
        self.hypervector_encoding = torchhd.embeddings.Density(
            self.num_feat, self.dimensions
        )

    def encode(self, x):
        """Encode data samples: Density embedding followed by clipping with ``kappa``."""
        return self.hypervector_encoding(x).clipping(self.kappa)

    def forward(self, x):
        """Make an inference step and return the predicted class index per sample."""
        # Make encodings for all data samples in the batch
        encodings = self.encode(x)
        # Get similarity values for each class assuming implicitly that there is only one prototype per class. This does not have to be the case in general.
        logit = self.classify(encodings)
        # Form predictions
        predictions = torch.argmax(logit, dim=-1)
        return predictions

    def fit(
        self,
        train_ld: data.DataLoader,
    ):
        """Train the classifier: compute the readout matrix and store it in ``classify``.

        Args:
            train_ld (data.DataLoader): Loader yielding ``(samples, labels)`` batches.
        """
        # Without an explicit device, follow wherever the module currently lives.
        device = self.device if self.device is not None else self.classify.weight.device
        # Gets classifier (readout matrix) via the ridge regression
        Wout = classifier_ridge_regression(
            train_ld,
            self.dimensions,
            self.num_classes,
            self.lamb,
            self.encode,
            self.hypervector_encoding.key.weight.dtype,
            device,
        )
        # Assign the obtained classifier to the output
        with torch.no_grad():
            self.classify.weight.copy_(Wout)
271+
272+
273+
# Specify device to be used for Torch.
274+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")
# Specifies batch size to be used for the model.
batch_size = 10
# Specifies how many random initializations of the model to evaluate for each dataset in the collection.
repeats = 5


# Get an instance of the UCI benchmark
benchmark = UCIClassificationBenchmark("../data", download=True)
# Perform evaluation
for dataset in benchmark.datasets():
    print(dataset.name)

    # Number of features in the dataset.
    num_feat = dataset.train[0][0].size(-1)
    # Number of classes in the dataset.
    num_classes = len(dataset.train.classes)

    # Get values for min-max normalization and add the transformation.
    # The statistics stay on the device of `dataset.train.data`: batches are only
    # moved to `device` after loading, so the transform must not mix devices.
    # NOTE(review): assumes the dataset applies `transform` to samples taken
    # from `dataset.train.data` -- confirm against the dataset implementation.
    min_val = torch.min(dataset.train.data, 0).values
    max_val = torch.max(dataset.train.data, 0).values
    transform = create_min_max_normalize(min_val, max_val)
    dataset.train.transform = transform
    dataset.test.transform = transform

    # Set up data loaders
    train_loader = data.DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
    test_loader = data.DataLoader(dataset.test, batch_size=batch_size)

    # Run for the requested number of simulations
    for _ in range(repeats):
        # Creates a model to be evaluated. The model should specify both transformation of input data as well as the algorithm for forming the classifier.
        model = IntRVFLRidge(
            getattr(torchhd.datasets, dataset.name), num_feat, device
        ).to(device)

        # Obtain the classifier for the model
        model.fit(train_loader)
        accuracy = torchmetrics.Accuracy("multiclass", num_classes=num_classes)

        with torch.no_grad():
            for samples, targets in tqdm(test_loader, desc="Testing"):
                samples = samples.to(device)
                # Make prediction
                predictions = model(samples)
                # The metric lives on the CPU, so predictions are moved back first.
                accuracy.update(predictions.cpu(), targets)

        benchmark.report(dataset, accuracy.compute().item())

# Returns a dictionary with names of the datasets and their respective accuracy that is averaged over folds (if applicable) and repeats
benchmark_accuracy = benchmark.score()
print(benchmark_accuracy)

torchhd/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from torchhd.datasets.dataset import CollectionDataset
1010
from torchhd.datasets.dataset import DatasetFourFold
1111
from torchhd.datasets.dataset import DatasetTrainTest
12+
from torchhd.datasets.dataset import UCIClassificationBenchmark
1213
from torchhd.datasets.abalone import Abalone
1314
from torchhd.datasets.adult import Adult
1415
from torchhd.datasets.acute_inflammation import AcuteInflammation
@@ -143,6 +144,7 @@
143144
"CollectionDataset",
144145
"DatasetFourFold",
145146
"DatasetTrainTest",
147+
"UCIClassificationBenchmark",
146148
"Abalone",
147149
"Adult",
148150
"AcuteInflammation",

0 commit comments

Comments
 (0)