-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugmentor.py
More file actions
115 lines (87 loc) · 3.46 KB
/
augmentor.py
File metadata and controls
115 lines (87 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import numpy as np
import torch_geometric.utils._subgraph as _subgraph
from torch_geometric.data import Data
from copy import deepcopy
def drop_nodes(data, drop_rate=0.1):
    """Iteratively drop up to ``drop_rate`` of the nodes, leaves first.

    Repeatedly finds "leaf" nodes (in-degree 0 in the *remaining* graph),
    removes a random sample of them, and finally extracts the induced
    subgraph over the surviving nodes with node indices relabelled.

    Args:
        data: a ``torch_geometric.data.Data``-like graph with ``x`` and
            ``edge_index`` (optionally ``edge_attr`` and ``y``).
        drop_rate: fraction of nodes to try to remove.

    Returns:
        A deep copy of ``data`` with the sampled nodes removed and, if
        present, the label ``y`` flipped.
    """
    data = deepcopy(data)
    node_num, _ = data.x.size()
    # Total number of nodes we aim to remove.
    total_drop_num = int(node_num * drop_rate)
    dropped_nodes = set()
    while len(dropped_nodes) < total_drop_num:
        # Recompute each node's in-degree over the remaining subgraph.
        # BUGFIX: an edge only survives if *both* endpoints survive; the
        # previous version skipped only edges whose target was dropped,
        # so edges out of already-dropped sources kept inflating the
        # degree of surviving targets, which could then never become
        # leaves.  (Also renamed: this counts *incoming* edges of v.)
        in_degrees = torch.zeros(node_num, dtype=torch.long)
        for u, v in data.edge_index.t():
            if u.item() not in dropped_nodes and v.item() not in dropped_nodes:
                in_degrees[v] += 1
        # Leaves: surviving nodes with no surviving incoming edge.
        leaf_nodes = [idx for idx in torch.where(in_degrees == 0)[0].tolist()
                      if idx not in dropped_nodes]
        # No new leaves means nothing more can be peeled off — stop early.
        if not leaf_nodes:
            break
        # How many to delete in this round.
        remaining_drop = total_drop_num - len(dropped_nodes)
        drop_num = min(len(leaf_nodes), remaining_drop)
        # Randomly pick leaves to delete.
        idx_drop = np.random.choice(leaf_nodes, drop_num, replace=False)
        dropped_nodes.update(idx_drop)
    # Nothing was dropped (e.g. drop_rate too small) — return unchanged copy.
    if not dropped_nodes:
        return data
    # Boolean keep-mask -> indices of the surviving nodes.
    keep_nodes = torch.ones(node_num, dtype=torch.bool)
    for node in dropped_nodes:
        keep_nodes[node] = False
    keep_nodes = torch.where(keep_nodes)[0]
    # Induced subgraph over the kept nodes, node ids relabelled to 0..K-1.
    edge_index, edge_attr = _subgraph.subgraph(
        keep_nodes,
        data.edge_index,
        data.edge_attr if 'edge_attr' in data else None,
        relabel_nodes=True
    )
    data.edge_index = edge_index
    if 'edge_attr' in data:
        data.edge_attr = edge_attr
    data.x = data.x[keep_nodes]
    if 'y' in data:
        # NOTE(review): flips a binary label — presumably the augmented
        # graph is used as a negative sample; confirm with the trainer.
        data.y = 1 - data.y
    return data
def permute_edges(data, permute_ratio=0.1):
    """Randomly delete a fraction of the edges.

    Despite the name, no edges are rewired: a random subset of size
    ``(1 - permute_ratio) * E`` is kept (together with the matching
    rows of ``edge_attr``, when present) and the rest are discarded.

    Args:
        data: a ``torch_geometric.data.Data``-like graph.
        permute_ratio: fraction of edges to remove.

    Returns:
        A deep copy of ``data`` with fewer edges and, if present, the
        label ``y`` flipped (NOTE(review): presumably a negative-sample
        marker; confirm with the training code).
    """
    data = deepcopy(data)
    _, edge_num = data.edge_index.size()
    permute_num = int(edge_num * permute_ratio)
    # Work on an (E, 2) numpy view of the edge list.
    edges = data.edge_index.transpose(0, 1).numpy()
    # Sample the edge indices that survive.
    idx_nonpermute = np.random.choice(edge_num, edge_num - permute_num,
                                      replace=False)
    data.edge_index = torch.tensor(edges[idx_nonpermute]).transpose_(0, 1)
    if 'edge_attr' in data:
        data.edge_attr = data.edge_attr[idx_nonpermute]
    if 'y' in data:
        data.y = 1 - data.y
    return data
def mask_nodes(data, mask_rate=0.1, embed_dim=768):
    """Shuffle the time-feature tail of a random subset of node rows.

    Each node feature vector is treated as an ``embed_dim``-dimensional
    word embedding followed by PTE time features; the time-feature tails
    of ``mask_rate * N`` randomly chosen rows are permuted among those
    rows, corrupting the node/time association while keeping the
    embeddings intact.

    Args:
        data: a ``torch_geometric.data.Data``-like graph; assumes
            ``data.x`` has more than ``embed_dim`` columns — TODO confirm.
        mask_rate: fraction of nodes whose tails are shuffled.
        embed_dim: width of the embedding prefix left untouched
            (generalized from the previous hard-coded 768).

    Returns:
        A deep copy of ``data`` with shuffled tails and, if present, the
        label ``y`` flipped (NOTE(review): presumably a negative-sample
        marker; confirm with the training code).
    """
    data = deepcopy(data)
    node_num, feat_dim = data.x.size()
    mask_num = int(node_num * mask_rate)
    # (Removed dead code: ``max_call_times`` was computed but never used.)
    idx_mask = np.random.choice(node_num, mask_num, replace=False)
    # Extract the tail (time-feature) part of the selected rows.
    tail_part = data.x[idx_mask][:, embed_dim:]
    # Permute those rows among themselves ...
    shuffled_indices = torch.randperm(mask_num)
    shuffled_tail = tail_part[shuffled_indices]
    # ... and write the shuffled tails back in place.
    data.x[idx_mask, embed_dim:] = shuffled_tail
    if 'y' in data:
        data.y = 1 - data.y
    return data
def random_aug(data, aug_rate=0.2):
    """Apply one randomly chosen augmentation to ``data``.

    Picks uniformly between node dropping and node masking and applies
    it with ``aug_rate`` as the augmentation strength.

    NOTE(review): ``permute_edges`` is absent from the candidate list —
    confirm whether that is intentional.
    """
    candidates = (drop_nodes, mask_nodes)
    pick = np.random.randint(len(candidates))
    return candidates[pick](data, aug_rate)