54 changes: 53 additions & 1 deletion README.md
@@ -108,6 +108,11 @@ The classifier supports several command line options for training configuration:
- `--plotDir`: Directory where figures are written (default: `./plots`)
- `--checkPointFrequency`: Number of epochs between model checkpoints (default: 10)

### Performance Benchmarking
- `--benchmark`: Enable performance benchmarking (tracks timing, throughput, GPU memory)
- `--benchmark-output`: Path to save benchmark results JSON file (default: `./benchmark_results.json`)
- `--eval-output`: Path to save evaluation metrics JSON file (default: `./evaluation_metrics.json`)

### Testing
- `--smoke-test`: Run minimal smoke test for CI (overrides other parameters for quick validation)

@@ -138,4 +143,51 @@ The following commands should be run on `checkers` **every time you create a new
cd nsfCssiMlClassifier
source envPyTorch.sh
source pgkyl/bin/activate
```
```

## Model Evaluation Metrics

The model evaluation system measures how well the classifier identifies X-points (magnetic reconnection sites) by treating detection as a pixel-level binary classification problem.

### Key Metrics

The evaluation outputs several metrics saved to JSON files:

- **Accuracy**: Overall pixel classification correctness (can be misleading due to class imbalance)
- **Precision**: Fraction of detected X-points that are correct (higher precision means fewer false alarms)
- **Recall**: Fraction of actual X-points that are found (higher recall means fewer missed X-points)
- **F1 Score**: Harmonic mean of precision and recall (balanced performance metric)
- **IoU**: Intersection over Union - spatial overlap quality between predicted and actual X-point regions
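
As a point of reference, the sketch below shows how these pixel-level metrics can be computed from a predicted and a ground-truth binary mask using NumPy. It is a minimal illustration of the definitions above, not the repository's `eval_metrics.ModelEvaluator` implementation.

```python
import numpy as np

def pixel_metrics(pred_mask, gt_mask, eps=1e-8):
    """Compute pixel-level binary classification metrics from 0/1 masks.

    Minimal illustration only -- not the repository's eval_metrics module.
    """
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)

    tp = np.logical_and(pred, gt).sum()    # detected X-point pixels that are correct
    fp = np.logical_and(pred, ~gt).sum()   # false alarms
    fn = np.logical_and(~pred, gt).sum()   # missed X-point pixels
    tn = np.logical_and(~pred, ~gt).sum()  # correctly rejected background

    accuracy  = (tp + tn) / (tp + tn + fp + fn + eps)
    precision = tp / (tp + fp + eps)
    recall    = tp / (tp + fn + eps)
    f1        = 2 * precision * recall / (precision + recall + eps)
    iou       = tp / (tp + fp + fn + eps)  # overlap of predicted and actual X-point regions

    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1_score": f1, "iou": iou}
```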

### Understanding the Results

**Good performance indicators:**
- F1 Score > 0.8
- IoU > 0.5
- Similar metrics between training and validation sets (no overfitting)
- Low standard deviation across frames (consistent performance)

**Warning signs:**
- Large gap between training and validation metrics (overfitting)
- High precision but low recall (too conservative, missing X-points)
- Low precision but high recall (too aggressive, many false alarms)
- High frame-to-frame variation (inconsistent detection)

### Output Files

After training, the model produces:
- `evaluation_metrics.json`: Validation set performance
- `train_evaluation_metrics.json`: Training set performance
- Performance plots in the `plots/` directory showing:
- Training history (loss curves)
- Model predictions vs ground truth
- True positives (green), false positives (red), false negatives (yellow)
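
For a quick post-training check against the overfitting warning above, the two metrics files can be compared directly. The snippet below is a rough sketch: it assumes the JSON exposes an `f1_score` entry at or near the top level, which may not match the exact schema written by `eval_metrics`; the file names follow the defaults listed above.

```python
import json
from pathlib import Path

def load_f1(path):
    """Pull an F1 score out of a saved metrics file.

    Assumes an 'f1_score' entry at the top level or under a 'global'
    section -- the real schema produced by eval_metrics may differ.
    """
    data = json.loads(Path(path).read_text())
    if "f1_score" in data:
        return data["f1_score"]
    return data.get("global", {}).get("f1_score")

train_f1 = load_f1("train_evaluation_metrics.json")
val_f1 = load_f1("evaluation_metrics.json")

if train_f1 is not None and val_f1 is not None:
    gap = train_f1 - val_f1
    print(f"Train F1: {train_f1:.4f}  Val F1: {val_f1:.4f}  Gap: {gap:+.4f}")
    if gap > 0.05:  # same 0.05 threshold the training script uses in its overfitting check
        print("Possible overfitting: train F1 exceeds validation F1 by more than 0.05")
```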

### Physics Context

For reconnection studies:
- **High recall is critical**: Missing X-points means missing reconnection events
- **Precision affects analysis**: False positives corrupt downstream calculations
- **IoU indicates localization**: Poor IoU means inaccurate X-point positions

The pipeline expands each X-point to a 9×9 pixel region to account for localization uncertainty while still requiring accurate region identification.
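
Conceptually, this expansion is a binary dilation: every X-point pixel is grown into a 9×9 square centered on it. The sketch below shows one way to express that idea with NumPy and SciPy; it is an illustration only, not necessarily how `expand_xpoints_mask` in `XPointMLTest.py` is implemented.

```python
import numpy as np
from scipy.ndimage import binary_dilation

def expand_xpoints_sketch(binary_mask, kernel_size=9):
    """Grow each nonzero pixel into a kernel_size x kernel_size square.

    Illustrative sketch only; the repository's expand_xpoints_mask may differ.
    """
    structure = np.ones((kernel_size, kernel_size), dtype=bool)
    return binary_dilation(binary_mask.astype(bool), structure=structure).astype(np.float32)
```
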
126 changes: 112 additions & 14 deletions XPointMLTest.py
@@ -25,6 +25,12 @@

from ci_tests import SyntheticXPointDataset, test_checkpoint_functionality

# Import benchmark module
from benchmark import TrainingBenchmark

# Import evaluation metrics module
from eval_metrics import ModelEvaluator, evaluate_model_on_dataset

def expand_xpoints_mask(binary_mask, kernel_size=9):
"""
Expands each X-point in a binary mask to include surrounding cells
@@ -481,11 +487,17 @@ def forward(self, inputs, targets):
return 1.0 - dice

# TRAIN & VALIDATION UTILS
def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype):
def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark=None):
model.train()
running_loss = 0.0

# Start epoch timing for benchmark
if benchmark:
benchmark.start_epoch()

for batch in loader:
batch_start = timer()

all_data, mask = batch["all"].to(device), batch["mask"].to(device)

with autocast(device_type='cuda', dtype=amp_dtype, enabled=use_amp):
Expand Down Expand Up @@ -514,6 +526,15 @@ def train_one_epoch(model, loader, criterion, optimizer, device, scaler, use_amp
optimizer.step()

running_loss += loss.item()

# Record batch timing for benchmark
if benchmark:
batch_time = timer() - batch_start
benchmark.record_batch(all_data.size(0), batch_time)

# End epoch timing for benchmark
if benchmark:
benchmark.end_epoch()

return running_loss / len(loader) if len(loader) > 0 else 0.0

@@ -672,20 +693,20 @@ def plot_model_performance(psi_np, pred_prob_np, mask_gt, x, y, params, fnum, fi

plt.close()

def plot_training_history(train_losses, val_losses, save_path='plots/training_history.png'):
def plot_training_history(train_losses, val_loss, save_path='plots/training_history.png'):
"""
Plots training and validation losses across epochs.

Parameters:
train_losses (list): List of training losses for each epoch
val_losses (list): List of validation losses for each epoch
val_loss (list): List of validation losses for each epoch
save_path (str): Path to save the resulting plot
"""
plt.figure(figsize=(10, 6))
epochs = range(1, len(train_losses) + 1)

plt.plot(epochs, train_losses, 'b-', label='Training Loss')
plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
plt.plot(epochs, val_loss, 'r-', label='Validation Loss')

plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
@@ -695,8 +716,8 @@ def plot_training_history(train_losses, val_losses, save_path='plots/training_hi
plt.grid(True, linestyle='--', alpha=0.7)

# Add some padding to y-axis to make visualization clearer
ymin = min(min(train_losses), min(val_losses)) * 0.9
ymax = max(max(train_losses), max(val_losses)) * 1.1
ymin = min(min(train_losses), min(val_loss)) * 0.9
ymax = max(max(train_losses), max(val_loss)) * 1.1
plt.ylim(ymin, ymax)

plt.savefig(save_path, dpi=300)
@@ -746,6 +767,12 @@ def parseCommandLineArgs():
choices=['float16', 'bfloat16'], help='data type for mixed precision (bfloat16 recommended)')
parser.add_argument('--patience', type=int, default=15,
help='patience for early stopping (default: 15)')
parser.add_argument('--benchmark', action='store_true',
help='enable performance benchmarking (tracks timing, throughput, GPU memory)')
parser.add_argument('--benchmark-output', type=Path, default='./benchmark_results.json',
help='path to save benchmark results JSON file (default: ./benchmark_results.json)')
parser.add_argument('--eval-output', type=Path, default='./evaluation_metrics.json',
help='path to save evaluation metrics JSON file (default: ./evaluation_metrics.json)')

# CI TEST: Add smoke test flag
parser.add_argument('--smoke-test', action='store_true',
@@ -974,6 +1001,11 @@ def main():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize benchmark tracker
benchmark = TrainingBenchmark(device, enabled=args.benchmark)
if args.benchmark:
benchmark.print_hardware_info()

# Use the improved model
model = UNet(input_channels=4, base_channels=32).to(device)

@@ -1028,14 +1060,21 @@

num_epochs = args.epochs
for epoch in range(start_epoch, num_epochs):
train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype)
train_loss_epoch = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler, use_amp, amp_dtype, benchmark)
val_loss_epoch = validate_one_epoch(model, val_loader, criterion, device, use_amp, amp_dtype)

train_loss.append(train_loss_epoch)
val_loss.append(val_loss_epoch)

current_lr = optimizer.param_groups[0]['lr']
print(f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}")

# Enhanced logging with benchmark metrics
log_msg = f"[Epoch {epoch+1}/{num_epochs}] LR={current_lr:.2e} TrainLoss={train_loss[-1]:.6f} ValLoss={val_loss[-1]:.6f}"
if args.benchmark:
throughput = benchmark.get_throughput()
gpu_mem = benchmark.get_gpu_memory_usage()
log_msg += f" | Throughput={throughput:.2f} samples/s | GPU Mem={gpu_mem:.2f} GB"
print(log_msg)

# Learning rate scheduling
scheduler.step()
@@ -1059,8 +1098,12 @@ def main():
print(f"Early stopping triggered after {epoch+1} epochs (patience={args.patience})")
break

plot_training_history(train_loss, val_loss)
plot_training_history(train_loss, val_loss, save_path='plots/training_history.png')
print("time (s) to train model: " + str(timer()-t2))

# Print and save benchmark summary
if args.benchmark:
benchmark.print_summary(output_file=args.benchmark_output)

# CI TEST: Run additional tests if in smoke test mode
if args.smoke_test:
@@ -1134,12 +1177,71 @@ def main():
print("Loading best model for evaluation...")
model.load_state_dict(torch.load(best_model_path, weights_only=True))

# new evaluation code
# Evaluate model performance
if not args.smoke_test:
# print("\n" + "="*70)
# print("RUNNING MODEL EVALUATION")
# print("="*70)

# # Evaluate on validation set
# print("\nEvaluating on validation set...")
val_evaluator = evaluate_model_on_dataset(
model,
val_dataset, # Use original dataset, not patch dataset
device,
use_amp=use_amp,
amp_dtype=amp_dtype,
threshold=0.5
)

# Print and save validation metrics
val_evaluator.print_summary()
val_evaluator.save_json(args.eval_output)

# Evaluate on training set
print("\nEvaluating on training set...")
train_evaluator = evaluate_model_on_dataset(
model,
train_dataset,
device,
use_amp=use_amp,
amp_dtype=amp_dtype,
threshold=0.5
)

# Print and save training metrics
train_evaluator.print_summary()
train_eval_path = args.eval_output.parent / f"train_{args.eval_output.name}"
train_evaluator.save_json(train_eval_path)

# Compare training vs validation to check for overfitting
train_global = train_evaluator.get_global_metrics()
val_global = val_evaluator.get_global_metrics()

print("\n" + "="*70)
print("OVERFITTING CHECK")
print("="*70)
print(f"Training F1: {train_global['f1_score']:.4f}")
print(f"Validation F1: {val_global['f1_score']:.4f}")
print(f"Difference: {abs(train_global['f1_score'] - val_global['f1_score']):.4f}")

if train_global['f1_score'] - val_global['f1_score'] > 0.05:
print("⚠ Warning: Possible overfitting detected (train F1 >> val F1)")
elif val_global['f1_score'] - train_global['f1_score'] > 0.05:
print("⚠ Warning: Unusual pattern (val F1 >> train F1)")
else:
print("✓ Model generalizes well to validation set")
print("="*70 + "\n")

# ==================== END NEW EVALUATION CODE ====================

# (D) Plotting after training
model.eval() # switch to inference mode
outDir = "plots"
interpFac = 1

# Evaluate on combined set for demonstration. Exam this part to see if save to remove
# Evaluate on combined set for demonstration
if not args.smoke_test:
train_fnums = range(args.trainFrameFirst, args.trainFrameLast)
val_fnums = range(args.validationFrameFirst, args.validationFrameLast)
@@ -1175,10 +1277,6 @@ def main():

pred_mask_bin = (pred_prob_np[0,0] > 0.5).astype(np.float32)

print(f"Frame {fnum} rotation {rotation} reflectionAxis {reflectionAxis}:")
print(f" Probabilities - min: {pred_prob_np.min():.5f}, max: {pred_prob_np.max():.5f}, mean: {pred_prob_np.mean():.5f}")
print(f" Binary Mask (X-points) - count of 1s: {np.sum(pred_mask_bin)} / {pred_mask_bin.size} pixels")

if args.plot:
# Plot GROUND TRUTH
plot_psi_contours_and_xpoints(