From 7c1df428a988d7a224387bea63e19e11a3156a8e Mon Sep 17 00:00:00 2001 From: Cameron Smith Date: Tue, 10 Jun 2025 14:15:44 -0400 Subject: [PATCH 1/5] broken plotting of rotated frames --- XPointMLTest.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/XPointMLTest.py b/XPointMLTest.py index abc706f..e46dddd 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -59,12 +59,24 @@ def expand_xpoints_mask(binary_mask, kernel_size=9): return expanded_mask +def plotSimple(arr, outfile): + plt.imshow(arr, interpolation="nearest", origin="upper") + plt.colorbar() + plt.savefig(outfile) + plt.clf() + def rotate(frameData,deg): if deg not in [90, 180, 270]: print(f"invalid rotation specified... exiting") sys.exit() psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) + + plotSimple(all[0], f"{frameData['fnum']}_rotation{deg}_all0.png") + plotSimple(all[1], f"{frameData['fnum']}_rotation{deg}_all1.png") + plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png") + plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png") + mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) return { "fnum": frameData["fnum"], @@ -85,6 +97,10 @@ def reflect(frameData,axis): sys.exit() psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) all = torch.flip(frameData["all"], dims=(axis,)) + plotSimple(all[0], f"{frameData['fnum']}_reflectionAxis{axis}_all0.png") + plotSimple(all[1], f"{frameData['fnum']}_reflectionAxis{axis}_all1.png") + plotSimple(all[2], f"{frameData['fnum']}_reflectionAxis{axis}_all2.png") + plotSimple(all[3], f"{frameData['fnum']}_reflectionAxis{axis}_all3.png") mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) return { "fnum": frameData["fnum"], @@ -258,6 +274,10 @@ def load(self, fnum): by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] + plotSimple(all_torch[0], f"{fnum}_all0.png") + plotSimple(all_torch[1], f"{fnum}_all1.png") + plotSimple(all_torch[2], f"{fnum}_all2.png") + plotSimple(all_torch[3], f"{fnum}_all3.png") mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] if self.verbosity > 0: From 4bbf134926eca3f73167b1c23dc890c8d97a72bc Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Mon, 16 Jun 2025 09:18:20 -0400 Subject: [PATCH 2/5] removed bug where train_loss and val_loss are reset to empty lists after loading from checkpoint --- XPointMLTest.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index e46dddd..758c9e5 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -861,9 +861,6 @@ def main(): t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) - train_loss = [] - val_loss = [] - num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) @@ -959,4 +956,4 @@ def main(): print("total time (s): " + str(t5-t0)) if __name__ == "__main__": - main() + main() \ No newline at end of file From 6fbc7725d207ec611e349877edd1d1da3c52a7a8 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Wed, 25 Jun 2025 15:10:33 -0400 Subject: [PATCH 3/5] fixed bugs in rotate and reflect functions --- XPointMLTest.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index 758c9e5..aa3a21f 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -77,7 +77,8 @@ def rotate(frameData,deg): plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png") plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png") - mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) + # mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) + mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST) return { "fnum": frameData["fnum"], "rotation": deg, @@ -95,13 +96,21 @@ def reflect(frameData,axis): if axis not in [0,1]: print(f"invalid reflection axis specified... exiting") sys.exit() - psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - all = torch.flip(frameData["all"], dims=(axis,)) + + # psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) + # all = torch.flip(frameData["all"], dims=(axis,)) + # mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) + + psi = torch.flip(frameData["psi"], dims=(axis+1,)) + all = torch.flip(frameData["all"], dims=(axis+1,)) + mask = torch.flip(frameData["mask"], dims=(axis+1,)) + plotSimple(all[0], f"{frameData['fnum']}_reflectionAxis{axis}_all0.png") plotSimple(all[1], f"{frameData['fnum']}_reflectionAxis{axis}_all1.png") plotSimple(all[2], f"{frameData['fnum']}_reflectionAxis{axis}_all2.png") plotSimple(all[3], f"{frameData['fnum']}_reflectionAxis{axis}_all3.png") - mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) + + return { "fnum": frameData["fnum"], "rotation": 0, From fc8b2bb67db8a6bb75976b7d8af4853374992219 Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Fri, 27 Jun 2025 14:09:47 -0400 Subject: [PATCH 4/5] removed commented lines of code in rotate and reflect function --- XPointMLTest.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index aa3a21f..b8f9389 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -77,7 +77,6 @@ def rotate(frameData,deg): plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png") plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png") - # mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST) return { "fnum": frameData["fnum"], @@ -96,11 +95,7 @@ def reflect(frameData,axis): if axis not in [0,1]: print(f"invalid reflection axis specified... exiting") sys.exit() - - # psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - # all = torch.flip(frameData["all"], dims=(axis,)) - # mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) - + psi = torch.flip(frameData["psi"], dims=(axis+1,)) all = torch.flip(frameData["all"], dims=(axis+1,)) mask = torch.flip(frameData["mask"], dims=(axis+1,)) From bfb0c98e67c8e017100e0109a8d4617ca8654ddd Mon Sep 17 00:00:00 2001 From: Swaroop Sridhar Date: Sat, 28 Jun 2025 01:03:59 -0400 Subject: [PATCH 5/5] removed plotSimple --- XPointMLTest.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/XPointMLTest.py b/XPointMLTest.py index b8f9389..50f4eff 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -59,25 +59,16 @@ def expand_xpoints_mask(binary_mask, kernel_size=9): return expanded_mask -def plotSimple(arr, outfile): - plt.imshow(arr, interpolation="nearest", origin="upper") - plt.colorbar() - plt.savefig(outfile) - plt.clf() - def rotate(frameData,deg): if deg not in [90, 180, 270]: print(f"invalid rotation specified... exiting") sys.exit() + psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) - - plotSimple(all[0], f"{frameData['fnum']}_rotation{deg}_all0.png") - plotSimple(all[1], f"{frameData['fnum']}_rotation{deg}_all1.png") - plotSimple(all[2], f"{frameData['fnum']}_rotation{deg}_all2.png") - plotSimple(all[3], f"{frameData['fnum']}_rotation{deg}_all3.png") - + # For mask, use nearest neighbor interpolation to preserve binary values mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST) + return { "fnum": frameData["fnum"], "rotation": deg, @@ -100,12 +91,6 @@ def reflect(frameData,axis): all = torch.flip(frameData["all"], dims=(axis+1,)) mask = torch.flip(frameData["mask"], dims=(axis+1,)) - plotSimple(all[0], f"{frameData['fnum']}_reflectionAxis{axis}_all0.png") - plotSimple(all[1], f"{frameData['fnum']}_reflectionAxis{axis}_all1.png") - plotSimple(all[2], f"{frameData['fnum']}_reflectionAxis{axis}_all2.png") - plotSimple(all[3], f"{frameData['fnum']}_reflectionAxis{axis}_all3.png") - - return { "fnum": frameData["fnum"], "rotation": 0, @@ -278,10 +263,6 @@ def load(self, fnum): by_torch = torch.from_numpy(fields["By"]).float().unsqueeze(0) jz_torch = torch.from_numpy(fields["Jz"]).float().unsqueeze(0) all_torch = torch.cat((psi_torch,bx_torch,by_torch,jz_torch)) # [4, Nx, Ny] - plotSimple(all_torch[0], f"{fnum}_all0.png") - plotSimple(all_torch[1], f"{fnum}_all1.png") - plotSimple(all_torch[2], f"{fnum}_all2.png") - plotSimple(all_torch[3], f"{fnum}_all3.png") mask_torch = torch.from_numpy(binaryMap).float().unsqueeze(0) # [1, Nx, Ny] if self.verbosity > 0: