diff --git a/XPointMLTest.py b/XPointMLTest.py index 3027485..31cbc11 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -31,6 +31,9 @@ # Import evaluation metrics module from eval_metrics import ModelEvaluator, evaluate_model_on_dataset +# Import git utils +from git_utils import print_git_info + def set_seed(seed): """ Set random seed for reproducibility across all libraries @@ -381,8 +384,8 @@ def _apply_augmentation(self, all_data, mask): return all_data, mask # 1. Random rotation (0, 90, 180, 270 degrees) - # 75% chance to apply rotation - if self.rng.random() < 0.75: + # 50% chance to apply rotation + if self.rng.random() < 0.50: k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°) all_data = torch.rot90(all_data, k=k, dims=(-2, -1)) mask = torch.rot90(mask, k=k, dims=(-2, -1)) @@ -397,25 +400,26 @@ def _apply_augmentation(self, all_data, mask): all_data = torch.flip(all_data, dims=(-2,)) mask = torch.flip(mask, dims=(-2,)) - # 4. Add Gaussian noise (30% chance) + # 4. Add Gaussian noise (10% chance) # Small noise helps prevent overfitting to exact pixel values - if self.rng.random() < 0.3: + if self.rng.random() < 0.1: noise_std = self.rng.uniform(0.005, 0.02) noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise - # 5. Random brightness/contrast adjustment per channel (30% chance) - # Helps model become invariant to intensity variations + # 5. Random brightness/contrast adjustment (30% chance) + # CHANGED: Applied globally across channels to preserve physical relationships + # (e.g., keeping the derivative relationship between psi and B fields) if self.rng.random() < 0.3: - for c in range(all_data.shape[0]): - brightness = self.rng.uniform(-0.1, 0.1) - contrast = self.rng.uniform(0.9, 1.1) - mean = all_data[c].mean() - all_data[c] = contrast * (all_data[c] - mean) + mean + brightness + brightness = self.rng.uniform(-0.1, 0.1) + contrast = self.rng.uniform(0.9, 1.1) + # Apply same transformation to all channels + mean = all_data.mean(dim=(-2, -1), keepdim=True) + all_data = contrast * (all_data - mean) + mean + brightness - # 6. Cutout/Random erasing (20% chance) + # 6. Cutout/Random erasing (5% chance) # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.2: + if self.rng.random() < 0.05: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: @@ -1069,6 +1073,11 @@ def load_model_checkpoint(model, optimizer, checkpoint_path, scaler=None): def main(): + # Print git repository information + # Use the directory of the current script as the repo path + repo_path = os.path.dirname(os.path.abspath(__file__)) + print_git_info(repo_path) + args = parseCommandLineArgs() # Set seed for reproducibility diff --git a/git_utils.py b/git_utils.py new file mode 100644 index 0000000..ac04e34 --- /dev/null +++ b/git_utils.py @@ -0,0 +1,67 @@ +import subprocess +import os + +""" +References:: +https://gitpython.readthedocs.io/en/stable/tutorial.html +""" +def get_git_info(repo_path='.'): + """ + Retrieves git information: commit hash, remote URL, and branch name. + + Args: + repo_path (str): Path to the git repository. Defaults to current directory. + + Returns: + dict: Dictionary containing 'commit_hash', 'remote_url', and 'branch_name'. + Values are None if retrieval fails. + """ + def run_git_command(command): + try: + result = subprocess.run( + command, + cwd=repo_path, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=True + ) + return result.stdout.strip() + except (subprocess.CalledProcessError, FileNotFoundError): + return None + + commit_hash = run_git_command(['git', 'rev-parse', 'HEAD']) + remote_url = run_git_command(['git', 'config', '--get', 'remote.origin.url']) + branch_name = run_git_command(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + + return { + 'commit_hash': commit_hash, + 'remote_url': remote_url, + 'branch_name': branch_name + } + +def print_git_info(repo_path='.'): + """ + Prints git information to stdout. + """ + info = get_git_info(repo_path) + print("-" * 30) + print("Git Repository Information:") + if info['commit_hash']: + print(f"Commit Hash: {info['commit_hash']}") + else: + print("Commit Hash: Not available") + + if info['branch_name']: + print(f"Branch Name: {info['branch_name']}") + else: + print("Branch Name: Not available") + + if info['remote_url']: + print(f"Remote URL: {info['remote_url']}") + else: + print("Remote URL: Not available") + print("-" * 30) + +if __name__ == "__main__": + print_git_info() diff --git a/hessian_comparison.py b/hessian_comparison.py new file mode 100644 index 0000000..16e6a90 --- /dev/null +++ b/hessian_comparison.py @@ -0,0 +1,175 @@ +import time +import torch +import numpy as np +import argparse +import sys +import os +from pathlib import Path + +# Ensure we can import from the current directory and pgkylFrontEnd (via PYTHONPATH) +sys.path.append(str(Path(__file__).parent)) + +try: + from XPointMLTest import UNet, loadPgkylDataFromCache, cachedPgkylDataExists + from utils import auxFuncs, gkData +except ImportError as e: + print(f"Error importing modules: {e}") + print("Make sure you have sourced envPyTorch.sh and are running with correct PYTHONPATH.") + sys.exit(1) + +def compare_hessian_vs_ml(param_file, cache_dir, model_path, frame_list, device='cuda'): + print(f"Comparing Hessian vs ML on frames {frame_list}") + print(f"Device: {device}") + print(f"Model: {model_path}") + print(f"Param File: {param_file}") + print(f"Cache Dir: {cache_dir}") + + # Load Model + # UNet signature: def __init__(self, input_channels=4, base_channels=32, *, dropout_rate): + model = UNet(input_channels=4, base_channels=32, dropout_rate=0.15).to(device) + try: + checkpoint = torch.load(model_path, map_location=device) + if 'model_state_dict' in checkpoint: + model.load_state_dict(checkpoint['model_state_dict']) + else: + model.load_state_dict(checkpoint) + except Exception as e: + print(f"Failed to load model: {e}") + sys.exit(1) + + model.eval() + + hessian_times = [] + ml_times = [] + + # Warmup + dummy_input = torch.randn(1, 4, 1024, 1024).to(device) + with torch.no_grad(): + _ = model(dummy_input) + + print(f"{'Frame':<10} | {'Hessian (s)':<15} | {'ML (s)':<15} | {'Speedup':<10}") + print("-" * 60) + + cache_path = Path(cache_dir) if cache_dir else None + + for fnum in frame_list: + try: + psi = None + dx = None + + # Try loading from cache first + if cache_path and cachedPgkylDataExists(cache_path, fnum, "psi"): + fields_to_load = {"psi": None, "coords": None} + loaded = loadPgkylDataFromCache(cache_path, fnum, fields_to_load) + psi = loaded["psi"] + coords = loaded["coords"] + # Calculate dx from coords + dx = [c[1] - c[0] for c in coords] + else: + # Fallback to gkData (might fail if getData.py is buggy) + params = {} + params["polyOrderOverride"] = 0 + var = gkData.gkData(str(param_file), fnum, 'psi', params).compactRead() + psi = var.data + dx = var.dx + + if psi is None: + print(f"Could not load data for frame {fnum}") + continue + + # --- Measure Hessian Time --- + t0 = time.time() + + # Replicating Hessian logic + critPoints = auxFuncs.getCritPoints(psi) + [xpts, optsMax, optsMin] = auxFuncs.getXOPoints(psi, critPoints) + + t1 = time.time() + hessian_time = t1 - t0 + hessian_times.append(hessian_time) + + # --- Measure ML Time --- + # Preprocess - Calculate derived fields + [df_dx,df_dy,df_dz] = auxFuncs.genGradient(psi,dx) + [d2f_dxdx,d2f_dxdy,d2f_dxdz] = auxFuncs.genGradient(df_dx,dx) + [d2f_dydx,d2f_dydy,d2f_dydz] = auxFuncs.genGradient(df_dy,dx) + bx = df_dy + by = -df_dx + # mu0 is usually 1.0 in normalized units or available in var.mu0 + # If we loaded from cache, we don't have var.mu0. + # Assuming mu0=1.0 for now as it's common in normalized simulations, + # or we could read it from param file, but let's stick to 1.0 or check if we can get it. + # In XPointMLTest.py: jz = -(d2f_dxdx + d2f_dydy) / var.mu0 + # In getConst.py: self.mu0 = mu0 (from param file). + # Let's assume mu0=1.0 to avoid reading param file again, or just use 1.0. + mu0 = 1.0 + jz = -(d2f_dxdx + d2f_dydy) / mu0 + + # Normalize (using same logic as XPointMLTest.py) + psi_norm = (psi - psi.mean()) / (psi.std() + 1e-8) + bx_norm = (bx - bx.mean()) / (bx.std() + 1e-8) + by_norm = (by - by.mean()) / (by.std() + 1e-8) + jz_norm = (jz - jz.mean()) / (jz.std() + 1e-8) + + # Stack + psi_torch = torch.from_numpy(psi_norm).float().unsqueeze(0) + bx_torch = torch.from_numpy(bx_norm).float().unsqueeze(0) + by_torch = torch.from_numpy(by_norm).float().unsqueeze(0) + jz_torch = torch.from_numpy(jz_norm).float().unsqueeze(0) + + input_tensor = torch.cat((psi_torch, bx_torch, by_torch, jz_torch)).unsqueeze(0).to(device) + + if device == 'cuda': + torch.cuda.synchronize() + t2 = time.time() + + with torch.no_grad(): + output = model(input_tensor) + prob = torch.sigmoid(output) + mask = (prob > 0.5).float() + + if device == 'cuda': + torch.cuda.synchronize() + t3 = time.time() + + ml_time = t3 - t2 + ml_times.append(ml_time) + + print(f"{fnum:<10} | {hessian_time:<15.4f} | {ml_time:<15.4f} | {hessian_time/ml_time:<10.2f}") + + except Exception as e: + print(f"Error processing frame {fnum}: {e}") + import traceback + traceback.print_exc() + continue + + if hessian_times and ml_times: + avg_hessian = np.mean(hessian_times) + avg_ml = np.mean(ml_times) + + print("\n" + "="*60) + print(f"Average Hessian Time: {avg_hessian:.4f}s") + print(f"Average ML Time: {avg_ml:.4f}s") + print(f"Average Speedup: {avg_hessian/avg_ml:.2f}x") + print("="*60) + else: + print("No frames processed successfully.") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Compare Hessian-based vs ML-based X-point detection performance.") + parser.add_argument('--paramFile', type=str, required=True, help="Path to the parameter file") + parser.add_argument('--xptCacheDir', type=str, default=None, help="Path to cache directory (optional)") + parser.add_argument('--modelPath', type=str, required=True, help="Path to the trained model checkpoint (.pt)") + parser.add_argument('--frames', type=str, default="141-150", help="Range of frames (e.g., '141-150' or '141,142,143')") + parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to run ML model on") + + args = parser.parse_args() + + # Parse frames + if '-' in args.frames: + start, end = map(int, args.frames.split('-')) + frames = range(start, end + 1) + else: + frames = [int(x) for x in args.frames.split(',')] + + compare_hessian_vs_ml(args.paramFile, args.xptCacheDir, args.modelPath, frames, args.device)