Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 46 additions & 5 deletions NLQ/VSLNet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
import torch.nn as nn
import submitit
import wandb
from torch.utils.tensorboard.writer import SummaryWriter
from model.VSLNet import build_optimizer_and_scheduler, VSLNet
from tqdm import tqdm
Expand Down Expand Up @@ -81,10 +82,30 @@ def main(configs, parser):

writer = None
if configs.log_to_tensorboard is not None:
assert not configs.log_to_wandb
log_dir = os.path.join(configs.tb_log_dir, configs.log_to_tensorboard)
os.makedirs(log_dir, exist_ok=True)
print(f"Writing to tensorboard: {log_dir}")
writer = SummaryWriter(log_dir=log_dir)
elif configs.log_to_wandb:
wandb.init(
# set the wandb project where this run will be logged
project="ego4d-nlq",
name=configs.task,

# track hyperparameters and run metadata
config={
"architecture": "VSLNet",
"task_name": configs.task,
"clip_norm": configs.clip_norm,
"epochs": configs.epochs,
"init_lr": configs.init_lr,
"batch_size": configs.batch_size,
"max_pos_len": configs.max_pos_len,
}
)



# train and test
if configs.mode.lower() == "train":
Expand All @@ -109,6 +130,7 @@ def main(configs, parser):
)
print("start training...", flush=True)
global_step = 0
num_examples_seen = 0
for epoch in range(configs.epochs):
model.train()
for data in tqdm(
Expand All @@ -127,6 +149,8 @@ def main(configs, parser):
e_labels,
h_labels,
) = data
num_examples_seen += vfeats.shape[0]

# prepare features
vfeats, vfeat_lens = vfeats.to(device), vfeat_lens.to(device)
s_labels, e_labels, h_labels = (
Expand Down Expand Up @@ -174,11 +198,20 @@ def main(configs, parser):
optimizer.step()
scheduler.step()
if writer is not None and global_step % configs.tb_log_freq == 0:
writer.add_scalar("Loss/Total", total_loss.detach().cpu(), global_step)
writer.add_scalar("Loss/Loc", loc_loss.detach().cpu(), global_step)
writer.add_scalar("Loss/Highlight", highlight_loss.detach().cpu(), global_step)
writer.add_scalar("Loss/Highlight (*lambda)", (configs.highlight_lambda * highlight_loss.detach().cpu()), global_step)
writer.add_scalar("LR", optimizer.param_groups[0]["lr"], global_step)
writer.add_scalar("Loss/Total", total_loss.detach().cpu(), num_examples_seen)
writer.add_scalar("Loss/Loc", loc_loss.detach().cpu(), num_examples_seen)
writer.add_scalar("Loss/Highlight", highlight_loss.detach().cpu(), num_examples_seen)
writer.add_scalar("Loss/Highlight (*lambda)", (configs.highlight_lambda * highlight_loss.detach().cpu()), num_examples_seen)
writer.add_scalar("LR", optimizer.param_groups[0]["lr"], num_examples_seen)
elif configs.log_to_wandb and global_step % configs.tb_log_freq == 0:
to_add = {
"Loss/Total": total_loss.detach().cpu(),
"Loss/Loc": loc_loss.detach().cpu(),
"Loss/Highlight": highlight_loss.detach().cpu(),
"Loss/Highlight (*lambda)": (configs.highlight_lambda * highlight_loss.detach().cpu()),
"LR": optimizer.param_groups[0]["lr"],
}
wandb.log(to_add, step=num_examples_seen)

# evaluate
if (
Expand Down Expand Up @@ -209,6 +242,12 @@ def main(configs, parser):
for name, value in score_dict.items():
kk = name.replace("\n", " ")
writer.add_scalar(f"Val/{kk}", value, global_step)
if configs.log_to_wandb:
vals_to_add = {}
for name, value in score_dict.items():
kk = name.replace("\n", " ")
vals_to_add[f"Val/{kk}"] = value
wandb.log(vals_to_add, step=num_examples_seen)

score_writer.write(score_str)
score_writer.flush()
Expand All @@ -227,6 +266,8 @@ def main(configs, parser):
model.train()

score_writer.close()
if configs.log_to_wandb:
wandb.finish()

elif configs.mode.lower() == "test":
if not os.path.exists(model_dir):
Expand Down
6 changes: 6 additions & 0 deletions NLQ/VSLNet/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def read_command_line():
default=None,
help="set to the last `_xxx` in ckpt repo to eval results",
)
parser.add_argument(
"--log_to_wandb",
action="store_true",
default=False,
help="Whether to log to wandb",
)
parser.add_argument(
"--log_to_tensorboard",
type=str,
Expand Down