diff --git a/README.md b/README.md
index 4025f26..8b0ea1a 100644
--- a/README.md
+++ b/README.md
@@ -85,10 +85,12 @@ run it
 The classifier supports several command line options for training configuration:
 
 ### Training Parameters
-- `--learningRate`: Learning rate for training (default: 1e-4)
-- `--batchSize`: Batch size for training (default: 8)
-- `--epochs`: Number of training epochs (default: 100)
-- `--minTrainingLoss`: Minimum reduction in training loss in orders of magnitude (default: 2, set to 0 to disable check)
+- `--learningRate`: Learning rate for training (default: 1e-5)
+- `--weightDecay`: Weight decay for L2 regularization (default: 1e-4)
+- `--dropoutRate`: Dropout rate for regularization (default: 0.2)
+- `--batchSize`: Batch size for training (default: 1)
+- `--epochs`: Number of training epochs (default: 2000)
+- `--minTrainingLoss`: Minimum reduction in training loss in orders of magnitude (default: 3, set to 0 to disable check)
 
 ### Data Configuration
 - `--trainFrameFirst`: First frame number for training data (default: 1)
@@ -102,11 +104,13 @@ The classifier supports several command line options for training configuration:
 - `--use-amp`: Enable automatic mixed precision training for faster training on modern GPUs
 - `--amp-dtype`: Data type for mixed precision (`float16` or `bfloat16`, default: `bfloat16`)
 - `--patience`: Patience for early stopping (default: 15 epochs)
+- `--seed`: Random seed for reproducibility (default: None for non-deterministic)
+- `--require-gpu`: Require GPU to be available, exit if not found
 
 ### Output and Monitoring
 - `--plot`: Enable creation of figures showing ground truth and model-identified X-points
 - `--plotDir`: Directory where figures are written (default: `./plots`)
-- `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10)
+- `--checkPointFrequency`: Number of epochs between model checkpoints (default: 100)
 
 ### Performance Benchmarking
 - `--benchmark`: Enable performance benchmarking (tracks timing, throughput, GPU memory)
@@ -118,18 +122,21 @@ The classifier supports several command line options for training configuration:
 
 ### Example with Advanced Options
 
-For faster training with mixed precision and early stopping:
-
+For training with custom regularization and reproducibility:
 ```bash
 python -u ${rcRoot}/reconClassifier/XPointMLTest.py \
 --paramFile=/path/to/params.txt \
 --xptCacheDir=/path/to/cache \
 --epochs 200 \
 --learningRate 1e-4 \
+--weightDecay 1e-3 \
+--dropoutRate 0.3 \
 --batchSize 16 \
 --use-amp \
 --amp-dtype bfloat16 \
 --patience 20 \
+--seed 42 \
+--require-gpu \
 --plot \
 --trainFrameLast 100 \
 --validationFrameLast 120
diff --git a/XPointMLTest.py b/XPointMLTest.py
index 22f1a69..3027485 100644
--- a/XPointMLTest.py
+++ b/XPointMLTest.py
@@ -31,6 +31,34 @@
 # Import evaluation metrics module
 from eval_metrics import ModelEvaluator, evaluate_model_on_dataset
 
+def set_seed(seed):
+    """
+    Set random seed for reproducibility across all libraries
+
+    Parameters:
+    seed (int): Random seed value
+    """
+    if seed is None:
+        return
+
+    print(f"Setting random seed to {seed} for reproducibility")
+
+    # Python random
+    import random
+    random.seed(seed)
+
+    # NumPy
+    np.random.seed(seed)
+
+    # PyTorch
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)  # for multi-GPU
+
+    # Make PyTorch deterministic (may reduce performance slightly)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
 def expand_xpoints_mask(binary_mask, kernel_size=9):
     """
     Expands each X-point in a binary mask to include surrounding cells
@@ -184,6 +212,7 @@ def __init__(self, paramFile, fnumList, xptCacheDir=None,
         """
        paramFile: Path to parameter file (string).
        fnumList: List of frames to iterate.
+       rotateAndReflect: If True, creates static augmented copies (deprecated, use on-the-fly instead)
        """
        super().__init__()
        self.paramFile = paramFile
@@ -299,13 +328,35 @@ def load(self, fnum):
         }
 
 class XPointPatchDataset(Dataset):
-    """On‑the‑fly square crops, balancing positive / background patches."""
-    def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30):
+    """On‑the‑fly square crops with data augmentation, balancing positive / background patches."""
+    def __init__(self, base_ds, patch=64, pos_ratio=0.5, retries=30, augment=False, seed=None):
+        """
+        Parameters:
+        -----------
+        base_ds : XPointDataset
+            Base dataset containing full frames
+        patch : int
+            Size of square patches to extract
+        pos_ratio : float
+            Target ratio of patches containing X-points
+        retries : int
+            Number of attempts to find a suitable patch
+        augment : bool
+            If True, apply on-the-fly data augmentation (use for training only)
+        seed : int or None
+            Random seed for reproducibility (None for non-deterministic)
+        """
         self.base_ds = base_ds
         self.patch = patch
         self.pos_ratio = pos_ratio
         self.retries = retries
-        self.rng = np.random.default_rng()
+        self.augment = augment
+
+        # Initialize RNG with seed if provided
+        if seed is not None:
+            self.rng = np.random.default_rng(seed)
+        else:
+            self.rng = np.random.default_rng()
 
     def __len__(self):
         # give each full frame K random crops per epoch (K=32 for more samples)
@@ -314,6 +365,66 @@ def __len__(self):
     def _crop(self, arr, top, left):
         return arr[..., top:top+self.patch, left:left+self.patch]
 
+    def _apply_augmentation(self, all_data, mask):
+        """
+        Apply random data augmentation to improve generalization
+
+        Augmentations applied:
+        - Random rotation (90°, 180°, 270°)
+        - Random horizontal flip
+        - Random vertical flip
+        - Gaussian noise injection
+        - Random brightness/contrast adjustment
+        - Cutout (random erasing)
+        """
+        if not self.augment:
+            return all_data, mask
+
+        # 1. Random rotation (0, 90, 180, 270 degrees)
+        # 75% chance to apply rotation
+        if self.rng.random() < 0.75:
+            k = self.rng.integers(1, 4)  # 1, 2, or 3 (90°, 180°, 270°)
+            all_data = torch.rot90(all_data, k=k, dims=(-2, -1))
+            mask = torch.rot90(mask, k=k, dims=(-2, -1))
+
+        # 2. Random horizontal flip (50% chance)
+        if self.rng.random() < 0.5:
+            all_data = torch.flip(all_data, dims=(-1,))
+            mask = torch.flip(mask, dims=(-1,))
+
+        # 3. Random vertical flip (50% chance)
+        if self.rng.random() < 0.5:
+            all_data = torch.flip(all_data, dims=(-2,))
+            mask = torch.flip(mask, dims=(-2,))
+
+        # 4. Add Gaussian noise (30% chance)
+        # Small noise helps prevent overfitting to exact pixel values
+        if self.rng.random() < 0.3:
+            noise_std = self.rng.uniform(0.005, 0.02)
+            noise = torch.randn_like(all_data) * noise_std
+            all_data = all_data + noise
+
+        # 5. Random brightness/contrast adjustment per channel (30% chance)
+        # Helps model become invariant to intensity variations
+        if self.rng.random() < 0.3:
+            for c in range(all_data.shape[0]):
+                brightness = self.rng.uniform(-0.1, 0.1)
+                contrast = self.rng.uniform(0.9, 1.1)
+                mean = all_data[c].mean()
+                all_data[c] = contrast * (all_data[c] - mean) + mean + brightness
+
+        # 6. Cutout/Random erasing (20% chance)
+        # Prevents model from relying too heavily on specific spatial features
+        if self.rng.random() < 0.2:
+            h, w = all_data.shape[-2:]
+            cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25))
+            if cutout_size > 0:
+                y = self.rng.integers(0, max(1, h - cutout_size))
+                x = self.rng.integers(0, max(1, w - cutout_size))
+                all_data[..., y:y+cutout_size, x:x+cutout_size] = 0
+
+        return all_data, mask
+
     def __getitem__(self, _):
         frame = self.base_ds[self.rng.integers(len(self.base_ds))]
         H, W = frame["mask"].shape[-2:]
@@ -321,9 +432,15 @@ def __getitem__(self, _):
         # Ensure we have enough space for cropping
         if H < self.patch or W < self.patch:
             # Return padded version if image is too small
+            all_data = F.pad(frame["all"], (0, max(0, self.patch - W), 0, max(0, self.patch - H)))
+            mask = F.pad(frame["mask"], (0, max(0, self.patch - W), 0, max(0, self.patch - H)))
+
+            # Apply augmentation if enabled
+            all_data, mask = self._apply_augmentation(all_data, mask)
+
             return {
-                "all": F.pad(frame["all"], (0, max(0, self.patch - W), 0, max(0, self.patch - H))),
-                "mask": F.pad(frame["mask"], (0, max(0, self.patch - W), 0, max(0, self.patch - H)))
+                "all": all_data,
+                "mask": mask
             }
 
         for attempt in range(self.retries):
@@ -334,8 +451,13 @@ def __getitem__(self, _):
             want_pos = (attempt / self.retries) < self.pos_ratio
 
             if has_pos == want_pos or attempt == self.retries - 1:
+                all_crop = self._crop(frame["all"], y0, x0)
+
+                # Apply augmentation if enabled
+                all_crop, crop_mask = self._apply_augmentation(all_crop, crop_mask)
+
                 return {
-                    "all" : self._crop(frame["all"], y0, x0),
+                    "all": all_crop,
                     "mask": crop_mask
                 }
 
@@ -381,7 +503,7 @@ class UNet(nn.Module):
     """
     Improved U-Net with residual blocks and better normalization
     """
-    def __init__(self, input_channels=4, base_channels=32):
+    def __init__(self, input_channels=4, base_channels=32, *, dropout_rate):
         super().__init__()
 
         # Encoder
@@ -391,23 +513,28 @@ def __init__(self, input_channels=4, base_channels=32):
         self.enc4 = ResidualBlock(base_channels*4, base_channels*8)
 
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.dropout = nn.Dropout2d(0.1)
+        self.dropout = nn.Dropout2d(dropout_rate)
 
         # Bottleneck
         self.bottleneck = ResidualBlock(base_channels*8, base_channels*16)
+        self.bottleneck_dropout = nn.Dropout2d(dropout_rate)
 
         # Decoder
         self.up4 = nn.ConvTranspose2d(base_channels*16, base_channels*8, kernel_size=2, stride=2)
         self.dec4 = ResidualBlock(base_channels*16, base_channels*8)
+        self.dec4_dropout = nn.Dropout2d(dropout_rate)
 
         self.up3 = nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=2, stride=2)
         self.dec3 = ResidualBlock(base_channels*8, base_channels*4)
+        self.dec3_dropout = nn.Dropout2d(dropout_rate)
 
         self.up2 = nn.ConvTranspose2d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
         self.dec2 = ResidualBlock(base_channels*4, base_channels*2)
+        self.dec2_dropout = nn.Dropout2d(dropout_rate)
 
         self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2)
         self.dec1 = ResidualBlock(base_channels*2, base_channels)
+        self.dec1_dropout = nn.Dropout2d(dropout_rate)
 
         self.out_conv = nn.Conv2d(base_channels, 1, kernel_size=1)
 
@@ -430,23 +557,28 @@ def forward(self, x):
 
         # Bottleneck
         b = self.bottleneck(p4)
+        b = self.bottleneck_dropout(b)
 
         # Decoder
         u4 = self.up4(b)
         cat4 = torch.cat([u4, e4], dim=1)
         d4 = self.dec4(cat4)
+        d4 = self.dec4_dropout(d4)
 
         u3 = self.up3(d4)
         cat3 = torch.cat([u3, e3], dim=1)
         d3 = self.dec3(cat3)
+        d3 = self.dec3_dropout(d3)
 
         u2 = self.up2(d3)
         cat2 = torch.cat([u2, e2], dim=1)
         d2 = self.dec2(cat2)
+        d2 = self.dec2_dropout(d2)
 
         u1 = self.up1(d2)
         cat1 = torch.cat([u1, e1], dim=1)
         d1 = self.dec1(cat1)
+        d1 = self.dec1_dropout(d1)
 
         out = self.out_conv(d1)
         return out
@@ -728,6 +860,10 @@ def parseCommandLineArgs():
     parser = argparse.ArgumentParser(description='ML-based reconnection classifier')
     parser.add_argument('--learningRate', type=float, default=1e-5,
                         help='specify the learning rate')
+    parser.add_argument('--weightDecay', type=float, default=1e-4,
+                        help='specify the weight decay (L2 regularization) for optimizer')
+    parser.add_argument('--dropoutRate', type=float, default=0.2,
+                        help='specify the dropout rate for regularization')
     parser.add_argument('--batchSize', type=int, default=1,
                         help='specify the batch size')
     parser.add_argument('--epochs', type=int, default=2000,
@@ -745,7 +881,7 @@ def parseCommandLineArgs():
                         minimum reduction in training loss in orders of magnitude,
                         set to 0 to disable the check (default: 3)
                         ''')
-    parser.add_argument('--checkPointFrequency', type=int, default=10,
+    parser.add_argument('--checkPointFrequency', type=int, default=100,
                         help='number of epochs between checkpoints')
     parser.add_argument('--paramFile', type=Path, default=None,
                         help='''
@@ -773,6 +909,10 @@ def parseCommandLineArgs():
                         help='path to save benchmark results JSON file (default: ./benchmark_results.json)')
     parser.add_argument('--eval-output', type=Path, default='./evaluation_metrics.json',
                         help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)')
+    parser.add_argument('--seed', type=int, default=None,
+                        help='random seed for reproducibility (default: None for non-deterministic)')
+    parser.add_argument('--require-gpu', action='store_true',
+                        help='require GPU to be available, exit if not found')
 
     # CI TEST: Add smoke test flag
     parser.add_argument('--smoke-test', action='store_true',
@@ -931,6 +1071,9 @@ def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None):
 
 def main():
     args = parseCommandLineArgs()
+    # Set seed for reproducibility
+    set_seed(args.seed)
+
     # CI TEST: Override parameters for smoke test
     if args.smoke_test:
         print("=" * 60)
@@ -972,53 +1115,79 @@ def main():
         val_dataset = SyntheticXPointDataset(nframes=1, shape=(64, 64), nxpoints=3, seed=123)
         print(f"Created synthetic datasets: {len(train_dataset)} train, {len(val_dataset)} val frames")
     else:
-        # Original data loading
+        # Original data loading - NO STATIC AUGMENTATION
         train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
         val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
 
         print(f"Loading training data from frames {args.trainFrameFirst} to {args.trainFrameLast-1}")
         print(f"Loading validation data from frames {args.validationFrameFirst} to {args.validationFrameLast-1}")
 
+        # Set rotateAndReflect=False - we'll use on-the-fly augmentation instead
         train_dataset = XPointDataset(args.paramFile, train_fnums,
-                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=True)
+                                      xptCacheDir=args.xptCacheDir, rotateAndReflect=False)
         val_dataset = XPointDataset(args.paramFile, val_fnums,
-                                    xptCacheDir=args.xptCacheDir)
+                                    xptCacheDir=args.xptCacheDir, rotateAndReflect=False)
 
-    # Use consistent pos_ratio for both training and validation
-    train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30)
-    val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30)
+    # Enable augmentation for training, disable for validation
+    train_crop = XPointPatchDataset(train_dataset, patch=64, pos_ratio=0.5, retries=30,
+                                    augment=True, seed=args.seed)
+    val_crop = XPointPatchDataset(val_dataset, patch=64, pos_ratio=0.5, retries=30,
+                                  augment=False, seed=args.seed)
 
     t1 = timer()
     print("time (s) to create gkyl data loader: " + str(t1-t0))
-    print(f"number of training frames (original + augmented): {len(train_dataset)}")
+    print(f"number of training frames: {len(train_dataset)}")
     print(f"number of validation frames: {len(val_dataset)}")
     print(f"number of training patches per epoch: {len(train_crop)}")
     print(f"number of validation patches per epoch: {len(val_crop)}")
+    print(f"Data augmentation: ENABLED for training, DISABLED for validation")
+    if args.seed is not None:
+        print(f"Random seed: {args.seed} (reproducible mode)")
+    else:
+        print(f"Random seed: None (non-deterministic mode)")
 
     train_loader = DataLoader(train_crop, batch_size=args.batchSize, shuffle=True, num_workers=0)
     val_loader = DataLoader(val_crop, batch_size=args.batchSize, shuffle=False, num_workers=0)
 
+    # Check GPU requirement
+    if args.require_gpu and not torch.cuda.is_available():
+        print("=" * 60)
+        print("ERROR: GPU required but not available!")
+        print("=" * 60)
+        print("The --require-gpu flag was set, but CUDA is not available.")
+        print("Please check:")
+        print("  1. NVIDIA GPU is properly installed")
+        print("  2. CUDA drivers are installed")
+        print("  3. PyTorch was installed with CUDA support")
+        print("\nExiting...")
+        sys.exit(1)
+
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
+    if args.require_gpu:
+        print("GPU requirement: ENABLED (will exit if GPU not available)")
+
 
     # Initialize benchmark tracker
     benchmark = TrainingBenchmark(device, enabled=args.benchmark)
     if args.benchmark:
         benchmark.print_hardware_info()
 
     # Use the improved model
-    model = UNet(input_channels=4, base_channels=32).to(device)
+    model = UNet(input_channels=4, base_channels=32, dropout_rate=args.dropoutRate).to(device)
 
     # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
+    print(f"Dropout rate: {args.dropoutRate}")
 
     criterion = DiceLoss(smooth=1.0)
 
     # Use AdamW optimizer with weight decay for better generalization
-    optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=1e-5)
+    optimizer = optim.AdamW(model.parameters(), lr=args.learningRate, weight_decay=args.weightDecay)
+    print(f"Optimizer: AdamW with learning_rate={args.learningRate}, weight_decay={args.weightDecay}")
 
     # Learning rate scheduler with cosine annealing
     scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)
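
A note on how the reproducibility pieces in this patch interact: the rotation, flip, brightness, and cutout decisions in `_apply_augmentation` are drawn from the dataset's seeded NumPy `Generator`, while the Gaussian-noise values come from `torch.randn_like` and therefore also depend on `set_seed` having seeded PyTorch's global RNG. The sketch below is illustrative only and is not part of the patch; the helper name `augment_once`, the tensor shapes, and the probabilities are assumptions chosen to mirror the code above.

```python
# Illustrative sketch (not part of the patch): how a seeded NumPy Generator,
# combined with a seeded torch RNG, makes the on-the-fly augmentation repeatable.
import numpy as np
import torch

def augment_once(data, mask, rng):
    """Apply a subset of the transforms used by _apply_augmentation."""
    if rng.random() < 0.75:                   # random 90/180/270 degree rotation
        k = int(rng.integers(1, 4))
        data = torch.rot90(data, k=k, dims=(-2, -1))
        mask = torch.rot90(mask, k=k, dims=(-2, -1))
    if rng.random() < 0.5:                    # horizontal flip
        data = torch.flip(data, dims=(-1,))
        mask = torch.flip(mask, dims=(-1,))
    if rng.random() < 0.3:                    # Gaussian noise: drawn from torch's
        std = float(rng.uniform(0.005, 0.02))  # global RNG, so torch seeding matters
        data = data + torch.randn_like(data) * std
    return data, mask

torch.manual_seed(0)
x = torch.randn(4, 64, 64)   # 4-channel input patch (shape is an assumption)
m = torch.zeros(1, 64, 64)   # binary X-point mask (shape is an assumption)

torch.manual_seed(1)
a = augment_once(x.clone(), m.clone(), np.random.default_rng(42))
torch.manual_seed(1)
b = augment_once(x.clone(), m.clone(), np.random.default_rng(42))
# Same NumPy seed and same torch RNG state give identical augmented patches.
assert all(torch.equal(u, v) for u, v in zip(a, b))
```

Without re-seeding torch between the two calls, the rotations and flips would still match (they come from the NumPy generator), but the noise branch could differ; that is why `set_seed` seeds Python, NumPy, and PyTorch together.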