-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset_utils.py
More file actions
53 lines (42 loc) · 2.09 KB
/
dataset_utils.py
File metadata and controls
53 lines (42 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# vim:fenc=utf-8
import torch_geometric.transforms as T
import warnings
warnings.filterwarnings('ignore')
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.datasets import Amazon
from torch_geometric.datasets import WikipediaNetwork
from torch_geometric.datasets import Actor
from torch_geometric.datasets import WebKB
from torch_geometric.datasets import LINKXDataset
from torch_geometric.datasets import AmazonProducts
def DataLoader(name, root_path='/home/jayee/datasets/'):
    """Load a graph benchmark dataset by name.

    Parameters
    ----------
    name : str
        Dataset name (case-insensitive). Supported: 'cora', 'citeseer',
        'pubmed', 'computers', 'photo', 'chameleon', 'squirrel', 'film',
        'texas', 'cornell', 'wisconsin', 'penn94', 'reed98', 'amherst41',
        'cornell5', 'johnshopkins55', 'genius', 'amazonproducts'.
    root_path : str, optional
        Directory under which datasets are stored/downloaded. Defaults to
        the original hard-coded location; note some branches concatenate a
        subfolder onto it directly, so a trailing '/' is expected.

    Returns
    -------
    tuple
        ``(dataset, data)`` where ``dataset`` is the PyG dataset object and
        ``data`` is its first graph (``dataset[0]``, possibly with a
        patched ``edge_index`` for chameleon/squirrel).

    Raises
    ------
    ValueError
        If ``name`` is not one of the supported datasets.
    """
    name = name.lower()
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(root_path, name, split='random', num_train_per_class=20, num_val=500, num_test=1000, transform=T.NormalizeFeatures())
    elif name in ['computers', 'photo']:
        # transform passed by keyword for consistency with the other branches
        # (it is the third positional parameter of Amazon, so behavior is unchanged)
        dataset = Amazon(root_path, name, transform=T.NormalizeFeatures())
    elif name in ['chameleon', 'squirrel']:
        # Use the graph structure (edge_index) from the raw version
        # ("geom_gcn_preprocess=False") and everything else, including the
        # node labels y, from the geom-gcn preprocessed version.
        # BUG FIX: the original passed geom_gcn_preprocess=True for BOTH
        # loads, which made preProcDs an identical copy of `dataset` and the
        # edge_index override below a no-op, contradicting the stated intent.
        preProcDs = WikipediaNetwork(
            root=root_path, name=name, geom_gcn_preprocess=False, transform=T.NormalizeFeatures())
        dataset = WikipediaNetwork(
            root=root_path, name=name, geom_gcn_preprocess=True, transform=T.NormalizeFeatures())
        data = dataset[0]
        data.edge_index = preProcDs[0].edge_index
        return dataset, data
    elif name in ['film']:
        # The Actor dataset has no `name` argument, so it gets its own subfolder.
        dataset = Actor(root=root_path+'Actor', transform=T.NormalizeFeatures())
    elif name in ['texas', 'cornell', 'wisconsin']:
        dataset = WebKB(root=root_path, name=name, transform=T.NormalizeFeatures())
    elif name in ["penn94", "reed98", "amherst41", "cornell5", "johnshopkins55", "genius"]:
        dataset = LINKXDataset(root=root_path, name=name, transform=T.NormalizeFeatures())
    elif name in ["amazonproducts"]:
        # AmazonProducts has no `name` argument either; give it a subfolder.
        dataset = AmazonProducts(root=root_path+'amazonproducts', transform=T.NormalizeFeatures())
    else:
        raise ValueError(f'dataset {name} not supported in dataloader')
    return dataset, dataset[0]