defaults.ini: 4 changes (0 additions, 4 deletions)
@@ -55,9 +55,6 @@ num_heads = 8
 # number of quantizers
 num_quantizers = 1
 
-# number of heads for the memcodes
-num_heads = 8
-
 # If true training data is kept in RAM
 cache_training_data = False
 
@@ -70,7 +67,6 @@ random_crop = False
 # normalize input audio?
 norm_inputs = False
 
-
 # for jukebox imbeddings. 0 (high res), 1 (med), or 2 (low res)
 jukebox_layer = 0
 
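Note on the defaults.ini hunk: the context line in the hunk header (@@ ... @@ num_heads = 8) shows that num_heads is already defined just above the deleted block, so this is a straight de-duplication plus a stray blank line. If the config is read with Python's standard configparser (an assumption; the repo's actual config loader is not shown in this diff), a repeated option is a hard error in strict mode. A minimal sketch, using a hypothetical [model] section name:

    import configparser

    # Strict mode (the default) refuses a repeated option within a section.
    cfg = configparser.ConfigParser()
    try:
        cfg.read_string("[model]\nnum_heads = 8\nnum_heads = 8\n")
    except configparser.DuplicateOptionError as err:
        print("duplicate key rejected:", err)

    # With strict=False the last occurrence silently wins instead.
    lenient = configparser.ConfigParser(strict=False)
    lenient.read_string("[model]\nnum_heads = 8\nnum_heads = 8\n")
    print(lenient["model"]["num_heads"])  # -> 8

Either behavior (a hard error or a silent override) argues for defining the key exactly once.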
train_icebox.py: 18 changes (15 additions, 3 deletions)
@@ -7,17 +7,20 @@
 from pathlib import Path
 
 import sys
+import os
 import torch
 from torch import optim, nn
 from torch.nn import functional as F
 from torch.utils import data
+import torch.distributed as dist
 from tqdm import trange
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.plugins import DDPPlugin
 
 from einops import rearrange
 from pprint import pprint
 
 import torchaudio
 from fairscale.nn import checkpoint_wrapper, auto_wrap, wrap
@@ -37,7 +40,7 @@
 from jukebox.make_models import make_vqvae, make_prior, MODELS, make_model
 from jukebox.hparams import Hyperparams, setup_hparams
 #from jukebox.sample import sample_single_window, _sample, sample_partial_window, upsample
-from jukebox.utils.dist_utils import setup_dist_from_mpi
+#from jukebox.utils.dist_utils import setup_dist_from_mpi
 #from jukebox.utils.torch_utils import empty_cache
 
 
@@ -117,8 +120,11 @@ def __init__(self, global_args):
         self.num_quantizers = global_args.num_quantizers
         self.ema_decay = global_args.ema_decay
-
-        rank, local_rank, device = setup_dist_from_mpi()
+        #rank, local_rank, device = self.local_rank, self.local_rank, self.device #TODO only works on 1 pod
+        #rank, local_rank, device = setup_dist_from_mpi()
+        rank, local_rank, device = int(os.getenv('RANK')), int(os.getenv('LOCAL_RANK')), self.device
+        dist_url = "tcp://127.0.0.1:9500"
+        dist.init_process_group(backend="nccl")
 
         self.hps = Hyperparams()
         assert global_args.sample_rate == 44100, "Jukebox was pretrained at 44100 Hz."
         self.hps.sr = global_args.sample_rate #44100
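This hunk is the substantive change: instead of bootstrapping distributed state through Jukebox's MPI helper (setup_dist_from_mpi, now commented out), the module reads RANK and LOCAL_RANK from the environment variables that PyTorch's distributed launchers and Lightning's DDP strategy export, then initializes the NCCL process group itself. One caveat: with no init_method argument, dist.init_process_group() defaults to the env:// rendezvous (MASTER_ADDR/MASTER_PORT), so the dist_url assignment above is never actually consumed. A self-contained sketch of the same pattern, with single-process fallback defaults added as assumptions:

    import os
    import torch
    import torch.distributed as dist

    def init_distributed_from_env():
        # torchrun and Lightning's DDP launcher export these variables.
        rank = int(os.getenv("RANK", "0"))
        local_rank = int(os.getenv("LOCAL_RANK", "0"))
        world_size = int(os.getenv("WORLD_SIZE", "1"))

        # No init_method given, so the env:// rendezvous is used; it reads
        # MASTER_ADDR and MASTER_PORT from the environment.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

        # Bind each process to the GPU matching its local rank.
        torch.cuda.set_device(local_rank)
        return rank, local_rank, torch.device("cuda", local_rank)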
@@ -303,6 +309,7 @@ def main():
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print('Using device:', device)
     torch.manual_seed(args.seed)
+    #dist.init_process_group(backend="nccl")
 
     train_set = SampleDataset([args.training_dir], args)
     train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True,
@@ -317,6 +324,11 @@
     wandb_logger.watch(module)
     push_wandb_config(wandb_logger, args)
 
+    #print(os.environ)
+    #for env in ['MASTER_ADDR','MASTER_PORT','RANK','LOCAL_RANK','WORLD_SIZE','GLOBAL_RANK']:
+    #    env_val = os.getenv(env)
+    #    print(f"{env}={env_val}")
+
     trainer = pl.Trainer(
         gpus=args.num_gpus,
         accelerator="gpu",
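The import hunk also moves from pytorch_lightning.plugins.DDPPlugin toward pytorch_lightning.strategies.DDPStrategy, matching Lightning's plugins-to-strategies rename in the 1.6 line. The Trainer call is truncated in this diff, so the actual strategy wiring is not visible; a hedged sketch of the usual pattern, with a placeholder GPU count standing in for the unshown args fields and find_unused_parameters as an illustrative choice, not the PR's setting:

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DDPStrategy

    trainer = pl.Trainer(
        gpus=2,  # placeholder for args.num_gpus
        accelerator="gpu",
        # strategy= replaces the older plugins=DDPPlugin(...) argument.
        strategy=DDPStrategy(find_unused_parameters=False),
    )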