diff --git a/XPointMLTest.py b/XPointMLTest.py index abc706f..50f4eff 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -63,9 +63,12 @@ def rotate(frameData,deg): if deg not in [90, 180, 270]: print(f"invalid rotation specified... exiting") sys.exit() + psi = v2.functional.rotate(frameData["psi"], deg, v2.InterpolationMode.BILINEAR) all = v2.functional.rotate(frameData["all"], deg, v2.InterpolationMode.BILINEAR) - mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.BILINEAR) + # For mask, use nearest neighbor interpolation to preserve binary values + mask = v2.functional.rotate(frameData["mask"], deg, v2.InterpolationMode.NEAREST) + return { "fnum": frameData["fnum"], "rotation": deg, @@ -83,9 +86,11 @@ def reflect(frameData,axis): if axis not in [0,1]: print(f"invalid reflection axis specified... exiting") sys.exit() - psi = torch.flip(frameData["psi"][0], dims=(axis,)).unsqueeze(0) - all = torch.flip(frameData["all"], dims=(axis,)) - mask = torch.flip(frameData["mask"][0], dims=(axis,)).unsqueeze(0) + + psi = torch.flip(frameData["psi"], dims=(axis+1,)) + all = torch.flip(frameData["all"], dims=(axis+1,)) + mask = torch.flip(frameData["mask"], dims=(axis+1,)) + return { "fnum": frameData["fnum"], "rotation": 0, @@ -841,9 +846,6 @@ def main(): t2 = timer() print("time (s) to prepare model: " + str(t2-t1)) - train_loss = [] - val_loss = [] - num_epochs = args.epochs for epoch in range(start_epoch, num_epochs): train_loss.append(train_one_epoch(model, train_loader, criterion, optimizer, device)) @@ -939,4 +941,4 @@ def main(): print("total time (s): " + str(t5-t0)) if __name__ == "__main__": - main() + main() \ No newline at end of file