Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions XPointMLTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
# Import evaluation metrics module
from eval_metrics import ModelEvaluator, evaluate_model_on_dataset

# Import git utils
from git_utils import print_git_info

def set_seed(seed):
"""
Set random seed for reproducibility across all libraries
Expand Down Expand Up @@ -381,8 +384,8 @@ def _apply_augmentation(self, all_data, mask):
return all_data, mask

# 1. Random rotation (0, 90, 180, 270 degrees)
# 75% chance to apply rotation
if self.rng.random() < 0.75:
# 50% chance to apply rotation
if self.rng.random() < 0.50:
k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°)
all_data = torch.rot90(all_data, k=k, dims=(-2, -1))
mask = torch.rot90(mask, k=k, dims=(-2, -1))
Expand All @@ -397,25 +400,26 @@ def _apply_augmentation(self, all_data, mask):
all_data = torch.flip(all_data, dims=(-2,))
mask = torch.flip(mask, dims=(-2,))

# 4. Add Gaussian noise (30% chance)
# 4. Add Gaussian noise (10% chance)
# Small noise helps prevent overfitting to exact pixel values
if self.rng.random() < 0.3:
if self.rng.random() < 0.1:
noise_std = self.rng.uniform(0.005, 0.02)
noise = torch.randn_like(all_data) * noise_std
all_data = all_data + noise

# 5. Random brightness/contrast adjustment per channel (30% chance)
# Helps model become invariant to intensity variations
# 5. Random brightness/contrast adjustment (30% chance)
# CHANGED: Applied globally across channels to preserve physical relationships
# (e.g., keeping the derivative relationship between psi and B fields)
if self.rng.random() < 0.3:
for c in range(all_data.shape[0]):
brightness = self.rng.uniform(-0.1, 0.1)
contrast = self.rng.uniform(0.9, 1.1)
mean = all_data[c].mean()
all_data[c] = contrast * (all_data[c] - mean) + mean + brightness
brightness = self.rng.uniform(-0.1, 0.1)
contrast = self.rng.uniform(0.9, 1.1)
# Apply same transformation to all channels
mean = all_data.mean(dim=(-2, -1), keepdim=True)
all_data = contrast * (all_data - mean) + mean + brightness

# 6. Cutout/Random erasing (20% chance)
# 6. Cutout/Random erasing (5% chance)
# Prevents model from relying too heavily on specific spatial features
if self.rng.random() < 0.2:
if self.rng.random() < 0.05:
h, w = all_data.shape[-2:]
cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25))
if cutout_size > 0:
Expand Down Expand Up @@ -1069,6 +1073,11 @@ def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None):


def main():
# Print git repository information
# Use the directory of the current script as the repo path
repo_path = os.path.dirname(os.path.abspath(__file__))
print_git_info(repo_path)

args = parseCommandLineArgs()

# Set seed for reproducibility
Expand Down
67 changes: 67 additions & 0 deletions git_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import subprocess
import os

"""
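This module shells out to the ``git`` command-line tool via ``subprocess``.

References (GitPython, a library-based alternative that is not used here)::
https://gitpython.readthedocs.io/en/stable/tutorial.html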
"""
def get_git_info(repo_path='.'):
    """
    Retrieves git information: commit hash, remote URL, and branch name.

    Each value is obtained by running the ``git`` executable with
    ``repo_path`` as its working directory. The three lookups are
    independent, so one failing does not prevent the others from being
    reported.

    Args:
        repo_path (str): Path to the git repository. Defaults to current directory.

    Returns:
        dict: Dictionary containing 'commit_hash', 'remote_url', and 'branch_name'.
            A value is None when its git command could not be started
            (git not installed, ``repo_path`` missing, not a directory, or
            not accessible) or exited with a non-zero status.
    """
    def run_git_command(command):
        # Run one git command and return its stripped stdout, or None on failure.
        try:
            result = subprocess.run(
                command,
                cwd=repo_path,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=True,
            )
        except (subprocess.CalledProcessError, OSError):
            # OSError (not just FileNotFoundError) so that a repo_path that is
            # a regular file or an unreadable directory also degrades to None
            # instead of crashing the caller; this is best-effort reporting.
            return None
        return result.stdout.strip()

    return {
        'commit_hash': run_git_command(['git', 'rev-parse', 'HEAD']),
        'remote_url': run_git_command(['git', 'config', '--get', 'remote.origin.url']),
        'branch_name': run_git_command(['git', 'rev-parse', '--abbrev-ref', 'HEAD']),
    }

def print_git_info(repo_path='.'):
    """
    Print a short git summary (commit, branch, remote) for ``repo_path`` to stdout.

    The report is framed by two 30-dash rules. Any field that
    ``get_git_info`` returned as empty or None is shown as "Not available".
    """
    details = get_git_info(repo_path)
    separator = "-" * 30

    commit = details['commit_hash']
    branch = details['branch_name']
    remote = details['remote_url']

    print(separator)
    print("Git Repository Information:")
    print(f"Commit Hash: {commit}" if commit else "Commit Hash: Not available")
    print(f"Branch Name: {branch}" if branch else "Branch Name: Not available")
    print(f"Remote URL: {remote}" if remote else "Remote URL: Not available")
    print(separator)

# When executed directly as a script, report the git state of the current working directory.
if __name__ == "__main__":
    print_git_info()
175 changes: 175 additions & 0 deletions hessian_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import time
import torch
import numpy as np
import argparse
import sys
import os
from pathlib import Path

# Ensure we can import from the current directory and pgkylFrontEnd (via PYTHONPATH)
sys.path.append(str(Path(__file__).parent))

try:
from XPointMLTest import UNet, loadPgkylDataFromCache, cachedPgkylDataExists
from utils import auxFuncs, gkData
except ImportError as e:
print(f"Error importing modules: {e}")
print("Make sure you have sourced envPyTorch.sh and are running with correct PYTHONPATH.")
sys.exit(1)

def compare_hessian_vs_ml(param_file, cache_dir, model_path, frame_list, device='cuda'):
    """
    Time Hessian-based X-point detection against UNet inference, frame by frame.

    For every frame the function loads ``psi`` (from the cache when available,
    otherwise through ``gkData``), times the two detectors, and prints one
    table row. A summary with the average times and speedup is printed at
    the end.

    What is timed:
        * Hessian: ``auxFuncs.getCritPoints`` followed by ``auxFuncs.getXOPoints``.
        * ML: the model forward pass, sigmoid, and 0.5 thresholding only.
          Building the input (gradients, normalisation, host-to-device copy)
          happens before the ML timer starts and is NOT included.

    A frame contributes to the averages only if both measurements completed,
    so the two averages are always taken over the same set of frames.

    Args:
        param_file: Path to the parameter file; passed to ``gkData.gkData`` as a string.
        cache_dir: Cache directory to try first, or None/empty to skip the cache.
        model_path: Path to the checkpoint. Either a bare state dict or a dict
            holding it under 'model_state_dict'.
        frame_list: Iterable of frame numbers to process.
        device: Device for the model, e.g. 'cuda', 'cuda:0', 'cpu', or a ``torch.device``.

    Returns:
        None. Results are printed to stdout. The process exits with status 1
        if the checkpoint cannot be loaded.
    """
    print(f"Comparing Hessian vs ML on frames {frame_list}")
    print(f"Device: {device}")
    print(f"Model: {model_path}")
    print(f"Param File: {param_file}")
    print(f"Cache Dir: {cache_dir}")

    # Load Model
    # UNet signature: def __init__(self, input_channels=4, base_channels=32, *, dropout_rate):
    model = UNet(input_channels=4, base_channels=32, dropout_rate=0.15).to(device)
    try:
        # NOTE(security): depending on the installed PyTorch version, torch.load
        # may unpickle arbitrary objects. Only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
    except Exception as e:
        print(f"Failed to load model: {e}")
        sys.exit(1)

    model.eval()

    # Decide on the device *type* rather than comparing against the literal
    # 'cuda', so that 'cuda:0' or a torch.device also gets synchronised.
    # Without the sync, asynchronous kernel launches make GPU timings too small.
    torch_device = torch.device(device)
    on_cuda = torch_device.type == 'cuda'

    def wait_for_gpu():
        # Block until all queued work on the benchmark device has finished.
        if on_cuda:
            torch.cuda.synchronize(torch_device)

    def speedup_of(reference_time, candidate_time):
        # Ratio reference/candidate, guarded against a zero-length measurement.
        if candidate_time > 0:
            return reference_time / candidate_time
        return float('inf')

    hessian_times = []
    ml_times = []

    # Warmup
    dummy_input = torch.randn(1, 4, 1024, 1024).to(device)
    with torch.no_grad():
        _ = model(dummy_input)

    print(f"{'Frame':<10} | {'Hessian (s)':<15} | {'ML (s)':<15} | {'Speedup':<10}")
    print("-" * 60)

    cache_path = Path(cache_dir) if cache_dir else None

    for fnum in frame_list:
        try:
            # Try loading from cache first
            if cache_path and cachedPgkylDataExists(cache_path, fnum, "psi"):
                fields_to_load = {"psi": None, "coords": None}
                loaded = loadPgkylDataFromCache(cache_path, fnum, fields_to_load)
                psi = loaded["psi"]
                coords = loaded["coords"]
                # Calculate dx from coords
                dx = [c[1] - c[0] for c in coords]
            else:
                # Fallback to gkData (might fail if getData.py is buggy)
                params = {}
                params["polyOrderOverride"] = 0
                var = gkData.gkData(str(param_file), fnum, 'psi', params).compactRead()
                psi = var.data
                dx = var.dx

            if psi is None:
                print(f"Could not load data for frame {fnum}")
                continue

            # --- Measure Hessian Time ---
            # perf_counter is the monotonic, high-resolution clock meant for benchmarking.
            t0 = time.perf_counter()

            # Replicating Hessian logic
            critPoints = auxFuncs.getCritPoints(psi)
            [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints)

            hessian_time = time.perf_counter() - t0

            # --- Measure ML Time ---
            # Preprocess - Calculate derived fields (not part of the timed region)
            [df_dx, df_dy, _] = auxFuncs.genGradient(psi, dx)
            [d2f_dxdx, _, _] = auxFuncs.genGradient(df_dx, dx)
            [_, d2f_dydy, _] = auxFuncs.genGradient(df_dy, dx)
            bx = df_dy
            by = -df_dx
            # XPointMLTest.py uses jz = -(d2f_dxdx + d2f_dydy) / var.mu0, but
            # var.mu0 is not available on the cache path. mu0 = 1.0 is used here;
            # because jz is z-score normalised just below, any positive constant
            # factor cancels (up to the 1e-8 epsilon).
            mu0 = 1.0
            jz = -(d2f_dxdx + d2f_dydy) / mu0

            # Normalize (using same logic as XPointMLTest.py) and stack the
            # four fields as channels: psi, bx, by, jz.
            channels = []
            for field in (psi, bx, by, jz):
                field_norm = (field - field.mean()) / (field.std() + 1e-8)
                channels.append(torch.from_numpy(field_norm).float().unsqueeze(0))

            input_tensor = torch.cat(channels).unsqueeze(0).to(device)

            # Make sure the host-to-device copy is done before starting the clock.
            wait_for_gpu()
            t2 = time.perf_counter()

            with torch.no_grad():
                output = model(input_tensor)
                prob = torch.sigmoid(output)
                # Thresholding is kept inside the timed region as part of detection.
                mask = (prob > 0.5).float()

            # Make sure the forward pass has actually finished before stopping the clock.
            wait_for_gpu()
            ml_time = time.perf_counter() - t2

            # Record both timings together so the averages cover the same frames
            # even when a frame fails part-way through.
            hessian_times.append(hessian_time)
            ml_times.append(ml_time)

            speedup = speedup_of(hessian_time, ml_time)
            print(f"{fnum:<10} | {hessian_time:<15.4f} | {ml_time:<15.4f} | {speedup:<10.2f}")

        except Exception as e:
            # Best-effort benchmark: report the failure and move on to the next frame.
            print(f"Error processing frame {fnum}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if hessian_times and ml_times:
        avg_hessian = np.mean(hessian_times)
        avg_ml = np.mean(ml_times)
        avg_speedup = speedup_of(avg_hessian, avg_ml)

        print("\n" + "="*60)
        print(f"Average Hessian Time: {avg_hessian:.4f}s")
        print(f"Average ML Time: {avg_ml:.4f}s")
        print(f"Average Speedup: {avg_speedup:.2f}x")
        print("="*60)
    else:
        print("No frames processed successfully.")

def parse_frame_spec(spec):
    """
    Expand a frame specification string into a list of frame numbers.

    The specification is a comma-separated list whose items are either a
    single frame number ('142') or an inclusive range ('141-150'). Items may
    be mixed freely, e.g. '141-145,150'. Empty items (such as a trailing
    comma) are ignored.

    Args:
        spec (str): The frame specification.

    Returns:
        list: Frame numbers in the order given.

    Raises:
        ValueError: If an item is not an integer or 'start-end' pair, if a
            range ends before it starts, or if no frame is specified at all.
    """
    frames = []
    for item in spec.split(','):
        item = item.strip()
        if not item:
            continue
        start_text, dash, end_text = item.partition('-')
        if dash:
            start, end = int(start_text), int(end_text)
            if end < start:
                # A reversed range would otherwise silently select no frames.
                raise ValueError(f"range '{item}' ends before it starts")
            frames.extend(range(start, end + 1))
        else:
            frames.append(int(item))
    if not frames:
        raise ValueError("no frames specified")
    return frames


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Compare Hessian-based vs ML-based X-point detection performance.")
    parser.add_argument('--paramFile', type=str, required=True, help="Path to the parameter file")
    parser.add_argument('--xptCacheDir', type=str, default=None, help="Path to cache directory (optional)")
    parser.add_argument('--modelPath', type=str, required=True, help="Path to the trained model checkpoint (.pt)")
    parser.add_argument('--frames', type=str, default="141-150", help="Range of frames (e.g., '141-150' or '141,142,143')")
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run ML model on")

    args = parser.parse_args()

    # Parse frames; report a malformed value as a normal usage error
    # instead of an uncaught traceback.
    try:
        frames = parse_frame_spec(args.frames)
    except ValueError as err:
        parser.error(f"invalid --frames value '{args.frames}': {err}")

    compare_hessian_vs_ml(args.paramFile, args.xptCacheDir, args.modelPath, frames, args.device)