Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.png
*.jpg
*.mp4
*.gif
*filelist.txt
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ python upload_data.py DATA_NAME PROJECT_NAME "Sample description of uploading ru
# This should be run on a system with a GPU (e.g., our server)
# training.py, sbatch scripts, and datasets should be in the same directory on the server (could be learning)
cd learning
python training.py MODEL_NAME PROJECT_NAME "Sample description of training run..." ARCHITECTURE_NAME DATA_NAME(S) --local_data
python training.py MODEL_NAME PROJECT_NAME "Sample description of training run..." ARCHITECTURE_NAME DATA_NAME(S) --local_data --use_augmentation

# Performs inference
# This will run on a system that can run Unreal Engine
# Note: MODEL_NAME_FROM_WANDB can be found in arcslaboratory -> Projects -> PROJECT_NAME -> Artifacts
# data_augmentations.py must be in the same directory as inference.py if performing inference on a model trained with data augmentation
cd learning
python inference.py INFERENCE_NAME PROJECT_NAME "Sample description of inference run..." MODEL_NAME_FROM_WANDB:VERSION IMAGE_SAVE_FOLDER_NAME
~~~
Expand Down
57 changes: 57 additions & 0 deletions download_wandb_project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import wandb
import os
import argparse

# example usage:
# python download_wandb_project.py Summer2024Official --output_dir Summer2024Official_downloads


def download_project_runs(project: str, output_dir: str, entity: str = "arcslaboratory"):
    """
    Download all runs and their artifacts from a specified WandB project.

    Args:
        project (str): The name of the WandB project.
        output_dir (str): The directory to save downloaded runs and artifacts.
        entity (str): The WandB entity (team or user) that owns the project.
            Defaults to "arcslaboratory" to preserve prior behavior.
    """
    api = wandb.Api()  # Initialize the WandB API

    runs = api.runs(f"{entity}/{project}")  # Get all runs for the project
    print(f"Found {len(runs)} runs in project '{project}'")

    os.makedirs(output_dir, exist_ok=True)  # Ensure output directory exists

    for run in runs:
        # One subdirectory per run, keyed by the unique run id (run names may collide)
        run_dir = os.path.join(output_dir, run.id)
        os.makedirs(run_dir, exist_ok=True)

        print(f"\nDownloading run: {run.name} ({run.id})")

        # Download all files associated with the run
        for file in run.files():
            print(f"File: {file.name}")
            file.download(root=run_dir, replace=True)  # Overwrite any stale local copy

        # Download all logged artifacts for the run
        for artifact in run.logged_artifacts():
            # NOTE(review): wandb's Artifact.name usually already ends in ":vN", so
            # appending artifact.version again may duplicate the suffix — confirm.
            # This label is only printed; it does not affect what is downloaded.
            artifact_name = f"{artifact.name.replace('/', '_')}:{artifact.version}"
            print(f"Artifact: {artifact_name}")
            artifact.download(root=run_dir)  # Download artifact contents into run_dir


if __name__ == "__main__":
    # Command-line interface: one required project name plus an optional
    # destination directory for the downloaded runs/artifacts.
    parser = argparse.ArgumentParser(
        description="Download all runs and artifacts from a WandB project."
    )
    parser.add_argument("project", help="WandB project name")
    parser.add_argument(
        "--output_dir",
        default="wandb_downloads",
        help="Directory to save the runs and artifacts",
    )

    cli_args = parser.parse_args()
    download_project_runs(cli_args.project, cli_args.output_dir)
89 changes: 89 additions & 0 deletions learning/data_augmentations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import albumentations as A
import numpy as np
from fastai.vision.core import PILImage
from fastai.vision.augment import RandTransform


class AlbumentationsTransform(RandTransform):
    """Fastai transform wrapper that applies Albumentations pipelines.

    Runs one augmentation pipeline on the training split and another
    (optionally the same one) on the validation split.
    """

    def __init__(self, train_aug, valid_aug=None, split_idx=None):
        """Constructor for AlbumentationsTransform.

        Args:
            train_aug: Albumentations pipeline applied to training data.
            valid_aug: Pipeline applied to validation data. Defaults to
                train_aug when omitted (None).
            split_idx: Which split the transform applies to (fastai semantics).
        """
        super().__init__()  # calls base class (RandTransform) constructor
        self.train_aug = train_aug
        # BUG FIX: the previous `valid_aug or train_aug` treated an EMPTY
        # pipeline (e.g. A.Compose([])) as falsy — albumentations Compose
        # defines __len__ — silently applying training augmentations to
        # validation data. Only fall back when valid_aug is actually None.
        self.valid_aug = train_aug if valid_aug is None else valid_aug
        self.split_idx = split_idx  # training vs. validation data marker
        self.order = 2  # apply after resizing

    def before_call(self, b, split_idx):
        """Record the split index (0 = train) before the transform is applied."""
        self.idx = split_idx

    def encodes(self, img: PILImage):
        """Apply the split-appropriate Albumentations pipeline to the input image."""
        aug = (
            self.train_aug if self.idx == 0 else self.valid_aug
        )  # apply the appropriate augmentation
        image = np.array(img)  # albumentations works with numpy arrays
        image = aug(image=image)[
            "image"
        ]  # extract the image from the augmentation result
        return PILImage.create(image)  # convert back to PILImage for compatibility


def get_train_aug():
    """Data augmentations applied to training data.

    Returns an Albumentations Compose with mild affine jitter and occasional
    brightness/contrast changes; further candidate augmentations are listed
    below (commented out) for future experimentation.
    """
    return A.Compose(
        [
            A.Affine(
                scale=(0.9, 1.1),  # scale by 90%-110% of original size
                translate_percent=0.1,  # shift horizontally or vertically by up to 10% of its width/height
                rotate=(-10, 10),  # rotate between -10 and 10 degrees
                p=0.5,  # 50% chance to apply affine transformations
            ),
            A.RandomBrightnessContrast(
                p=0.2  # 20% chance to adjust brightness and contrast
            ),
            # possible augmentations to add:
            # A.Perspective(
            #     scale=(0.05, 0.1),  # apply perspective transformation with a scale factor between 5% and 10%
            #     p=0.5  # 50% chance to apply perspective transformation
            # ),
            # A.HueSaturationValue(
            #     hue_shift_limit=20,  # shift hue by up to 20 degrees
            #     sat_shift_limit=20,  # shift saturation by up to 20%
            #     val_shift_limit=20,  # shift value by up to 20%
            #     p=0.5  # 50% chance to apply hue, saturation, and value adjustments
            # ),
            # A.RandomGamma(
            #     gamma_limit=(80, 120),  # adjust gamma between 80% and 120%
            #     p=0.5  # 50% chance to apply gamma adjustment
            # ),
            # A.RGBShift(
            #     r_shift_limit=20,  # shift red channel by up to 20
            #     g_shift_limit=20,  # shift green channel by up to 20
            #     b_shift_limit=20,  # shift blue channel by up to 20
            #     p=0.5  # 50% chance to apply RGB shift
            # ),
            # A.MotionBlur(
            #     blur_limit=(3, 7),  # apply motion blur with a kernel size between 3 and 7
            #     p=0.5  # 50% chance to apply motion blur
            # ),
            # A.GaussianNoise(
            #     var_limit=(10, 50),  # add Gaussian noise with a variance between 10 and 50
            #     p=0.5  # 50% chance to apply Gaussian noise
            # ),
            # A.OpticalDistortion(
            #     distort_limit=0.05,  # apply optical distortion with a limit of 5%
            #     shift_limit=0.05,  # shift the image by up to 5%
            #     p=0.5  # 50% chance to apply optical distortion
            # ),
        ]
    )


def get_valid_aug():
    """Return the validation-time pipeline: an empty Compose, so no
    augmentations are applied to validation data."""
    no_transforms = []
    return A.Compose(no_transforms)
58 changes: 49 additions & 9 deletions learning/training/training.py → learning/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from fastai.vision.utils import get_image_files
from torch import nn

from fastai.vision.all import aug_transforms, Normalize, imagenet_stats

def parse_args() -> Namespace:
arg_parser = ArgumentParser("Train command classification networks.")
Expand Down Expand Up @@ -70,6 +71,11 @@ def parse_args() -> Namespace:
)

# Training configuration
arg_parser.add_argument(
"--use_augmentation",
action="store_true",
help="Enable data augmentation if included (default is off).",
)
arg_parser.add_argument(
"--num_epochs", type=int, default=10, help="Number of training epochs."
)
Expand Down Expand Up @@ -130,15 +136,31 @@ def y_from_filename(rotation_threshold: float, filename: str) -> str:

Example: "path/to/file/001_000011_-1p50.png" --> "right"
"""
filename_stem = Path(filename).stem
angle = float(filename_stem.split("_")[2].replace("p", "."))

if angle > rotation_threshold:
return "left"
elif angle < -rotation_threshold:
return "right"
else:
return "forward"
path = Path(filename)
filename_stem = path.stem
parts = filename_stem.split("_")
direction_keywords = {"left", "right", "forward"}

# Case 1: filename starts with a known direction
if parts[0].lower() in direction_keywords:
return parts[0].lower()

# Case 2: try to parse angle from third underscore-separated part
if len(parts) >= 3:
try:
angle_str = parts[2].replace("p", ".")
angle = float(angle_str)
if angle > rotation_threshold:
return "left"
elif angle < -rotation_threshold:
return "right"
else:
return "forward"
except ValueError:
pass # fall through to fallback

# Fallback: get label from parent directory
return path.parent.name.lower()


def get_dls(args: Namespace, data_paths: list):
Expand All @@ -163,6 +185,24 @@ def get_dls(args: Namespace, data_paths: list):
shuffle=True,
bs=args.batch_size,
item_tfms=Resize(args.image_resize),
batch_tfms=[
*aug_transforms( # apply fastai's data augmentation transforms
size=args.image_resize, # scales images to be image_resize x image_resize
flip_vert=False, # vertical flip is not used
max_rotate=10.0, # rotate images by up to 10 degrees
min_zoom=0.9, # zoom images down to 90% of their original size
max_zoom=1.1, # zoom images up to 110% of their original size
max_lighting=0.2, # adjust lighting by up to 20%
max_warp=0.2, # warp images by up to 20%
p_affine=0.5, # probability of applying affine transformations (rotation, zoom, warp)
p_lighting=0.2, # probability of applying lighting adjustments
),
Normalize.from_stats(
*imagenet_stats
), # normalize images using ImageNet statistics
]
if args.use_augmentation
else None,
)


Expand Down