Skip to content
59 changes: 55 additions & 4 deletions tgrag/experiments/gnn_experiments/weak_supervision_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pandas as pd
import torch
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import degree
from tqdm import tqdm

from tgrag.dataset.temporal_dataset import TemporalDataset
Expand All @@ -16,6 +17,10 @@
from tgrag.utils.logger import setup_logging
from tgrag.utils.matching import reverse_domain
from tgrag.utils.path import get_root_dir, get_scratch
from tgrag.utils.plot import (
plot_neighbor_degree_distribution,
plot_neighbor_distribution,
)
from tgrag.utils.seed import seed_everything

parser = argparse.ArgumentParser(
Expand All @@ -34,6 +39,7 @@ def run_weak_supervision_forward(
model_arguments: ModelArguments,
dataset: TemporalDataset,
weight_directory: Path,
target: str,
) -> None:
root = get_root_dir()
phishing_dict: Dict[str, str] = {
Expand All @@ -42,6 +48,13 @@ def run_weak_supervision_forward(
'PhishTank': 'data/phishing_data/cc_dec_2024_phishtank_domains.csv',
}
data = dataset[0]

src, dst = data.edge_index
logging.info(f'Src, dst degrees loaded.')

out_degree = degree(src, num_nodes=data.num_nodes, dtype=torch.long)
in_degree = degree(dst, num_nodes=data.num_nodes, dtype=torch.long)

device = f'cuda:{model_arguments.device}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device)
logging.info(f'Device found: {device}')
Expand Down Expand Up @@ -72,8 +85,8 @@ def run_weak_supervision_forward(
phishing_loader = NeighborLoader(
data,
input_nodes=phishing_indices,
num_neighbors=[30, 30, 30],
batch_size=1024,
num_neighbors=model_arguments.num_neighbors,
batch_size=model_arguments.batch_size,
shuffle=False,
)
logging.info(
Expand All @@ -82,17 +95,54 @@ def run_weak_supervision_forward(

num_nodes = data.num_nodes
all_preds = torch.zeros(num_nodes, 1)
neighbor_preds = []
neighbor_nodes = set()

with torch.no_grad():
for batch in tqdm(phishing_loader, desc=f'{dataset_name} batch'):
batch = batch.to(device)
preds = model(batch.x, batch.edge_index)
seed_nodes = batch.n_id[: batch.batch_size]

pred_neighbors = preds[batch.batch_size :]
neighbor_preds.append(pred_neighbors.cpu())
neighbor_nodes.update(batch.n_id[batch.batch_size :].tolist())

all_preds[seed_nodes] = preds[: batch.batch_size].cpu()

neighbor_preds = torch.cat(neighbor_preds, dim=0)
neighbor_nodes = torch.tensor(list(neighbor_nodes), dtype=torch.long)

neighbor_in_degree = in_degree[neighbor_nodes]
logging.info(f'Size of in-degree tensor: {neighbor_in_degree.size()}')
logging.info(f'Sample of in-degree: {neighbor_in_degree[:10]}')
neighbor_out_degree = out_degree[neighbor_nodes]
logging.info(f'Size of out-degree tensor: {neighbor_out_degree.size()}')
logging.info(f'Sample of out-degree: {neighbor_out_degree[:10]}')

plot_neighbor_distribution(
neighbor_preds=neighbor_preds,
dataset_name=dataset_name,
model_name=model_arguments.model,
target=target,
)
plot_neighbor_degree_distribution(
neighbor_degree=neighbor_in_degree,
dataset_name=dataset_name,
model_name=model_arguments.model,
target=target,
degree='In-degree',
)
plot_neighbor_degree_distribution(
neighbor_degree=neighbor_out_degree,
dataset_name=dataset_name,
model_name=model_arguments.model,
target=target,
degree='Out-degree',
)
logging.info(f'Saving distribution of {dataset_name}')
preds = all_preds[phishing_indices]
logging.info(f'Number of predictions: {preds.size()}')
logging.info(f'Predictions: {preds}')
for threshold in [0.1, 0.3, 0.5]:
upper = dataset_name == 'IP2Location'
accuracy = get_accuracy(preds, threshold=threshold, upper=upper)
Expand Down Expand Up @@ -145,7 +195,7 @@ def main() -> None:
encoding=encoding_dict,
seed=meta_args.global_seed,
processed_dir=f'{scratch}/{meta_args.processed_location}',
) # Map to .to_cpu()
)
logging.info('In-Memory Dataset loaded.')
weight_directory = (
root / cast(str, meta_args.weights_directory) / f'{meta_args.target_col}'
Expand All @@ -157,6 +207,7 @@ def main() -> None:
experiment_arg.model_args,
dataset,
weight_directory,
target=meta_args.target_col,
)


Expand Down
56 changes: 56 additions & 0 deletions tgrag/utils/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,59 @@ def plot_pred_target_distributions_histogram(
plt.tight_layout()
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1)
plt.close()


def plot_neighbor_distribution(
    neighbor_preds: Tensor, dataset_name: str, model_name: str, target: str
) -> None:
    """Save a histogram of predicted labels over sampled neighbor nodes.

    The figure is written to
    ``results/plots/<model_name>/distribution/<target>/`` under the project
    root, as ``<dataset_name>_neighbor_pred_distribution.png``.

    Args:
        neighbor_preds: CPU tensor of neighbor predictions; values are
            binned into [0, 1] (20 bins).
        dataset_name: Dataset label used in the title and the filename.
        model_name: Model name; selects the plot sub-directory.
        target: Target column name; selects the leaf plot directory.
    """
    out_dir = (
        get_root_dir() / 'results' / 'plots' / model_name / 'distribution' / target
    )
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f'{dataset_name}_neighbor_pred_distribution.png'

    plt.figure(figsize=(6, 4))
    plt.hist(neighbor_preds.numpy(), bins=20, range=(0, 1), edgecolor='black')
    plt.title(f'Predicted Label Distribution (Neighbors) — {dataset_name}')
    plt.xlabel('Predicted label (0, 1)')
    plt.ylabel('Frequency')
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()


def plot_neighbor_degree_distribution(
    neighbor_degree: Tensor,
    dataset_name: str,
    model_name: str,
    target: str,
    degree: str,
) -> None:
    """Save a bar chart of the degree distribution over neighbor nodes.

    The figure is written to
    ``results/plots/<model_name>/distribution/<target>/`` under the project
    root, as ``<dataset_name>_neighbor_<degree>_degree_distribution.png``.

    Args:
        neighbor_degree: CPU integer tensor of per-node degrees; zero-degree
            entries are excluded from the plot.
        dataset_name: Dataset label used in the title and the filename.
        model_name: Model name; selects the plot sub-directory.
        target: Target column name; selects the leaf plot directory.
        degree: Human-readable degree label (e.g. 'In-degree', 'Out-degree'),
            embedded in the title, axis label and filename.
    """
    root = get_root_dir()
    save_dir = root / 'results' / 'plots' / model_name / 'distribution' / target
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f'{dataset_name}_neighbor_{degree}_degree_distribution.png'
    plt.figure(figsize=(6, 4))

    # Exclude zero-degree nodes so the chart only shows connected neighbors.
    deg = neighbor_degree[neighbor_degree > 0]

    # torch.unique returns its values already sorted ascending (sorted=True
    # is the default), so no extra argsort pass is needed before plotting.
    unique_deg, counts = torch.unique(deg, return_counts=True)

    plt.bar(
        unique_deg.numpy(),
        counts.numpy(),
        width=0.8,
        edgecolor='black',
        align='center',
    )
    plt.title(f'{degree} Distribution (Neighbors) — {dataset_name}')
    plt.xlabel(f'{degree}')
    plt.ylabel('Frequency')
    plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()