+# Environment Setup (Linux)
-### Dependencies
-python 2.7
+## Clone the repo and install conda (if not available)
-Pytorch
+- `git clone https://github.com/Redcof/StackGAN-Pytorch.git`
+- `wget https://repo.anaconda.com/miniconda/Miniconda3-py38_23.3.1-0-Linux-x86_64.sh`
+- `bash Miniconda3-py38_23.3.1-0-Linux-x86_64.sh -b`
+- `$HOME/miniconda3/bin/conda init`
+- `source $HOME/.bashrc`
-In addition, please add the project folder to PYTHONPATH and `pip install` the following packages:
-- `tensorboard`
-- `python-dateutil`
-- `easydict`
-- `pandas`
-- `torchfile`
+## Create environment
+- `conda create -n ganenv python=3.8`
+- `conda activate ganenv`
+## Install dependencies
-**Data**
+- `pip install -r requirements.txt`
+- `conda install -c conda-forge fasttext`
+- `conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia`
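+
+`requirements.txt` itself is not shown in this diff; judging only from the imports the diff adds, it presumably
+covers at least the packages below (an assumption, not the authoritative list):
+
+```text
+easydict         # miscc/config.py
+pandas           # miscc/datasets.py
+torchfile        # code/trainer.py
+python-dateutil  # code/main.py
+GitPython        # code/main.py (from git import Repo)
+tensorboardX     # code/trainer.py
+Pillow           # PIL imports
+numpy
+PyYAML           # miscc/config.py (import yaml)
+```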
-1. Download our preprocessed char-CNN-RNN text embeddings for [training coco](https://drive.google.com/open?id=0B3y_msrWZaXLQXVzOENCY2E3TlU) and [evaluating coco](https://drive.google.com/open?id=0B3y_msrWZaXLeEs5MTg0RC1fa0U), save them to `data/coco`.
- - [Optional] Follow the instructions [reedscot/icml2016](https://github.com/reedscot/icml2016) to download the pretrained char-CNN-RNN text encoders and extract text embeddings.
-2. Download the [coco](http://cocodataset.org/#download) image data. Extract them to `data/coco/`.
+## Install CUDA drivers (if not available)
+**How to check?**
+```commandline
+python cuda_test.py  # should print True if CUDA is available
+```
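+
+`cuda_test.py` is not included in this diff; a minimal sketch of such a check, assuming it only needs to report
+whether PyTorch can see a GPU, would be:
+
+```python
+# cuda_test.py -- hypothetical minimal check
+import torch
+
+print(torch.cuda.is_available())  # True when a CUDA device and a matching driver are visible
+```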
-**Training**
-- The steps to train a StackGAN model on the COCO dataset using our preprocessed embeddings.
- - Step 1: train Stage-I GAN (e.g., for 120 epochs) `python main.py --cfg cfg/coco_s1.yml --gpu 0`
- - Step 2: train Stage-II GAN (e.g., for another 120 epochs) `python main.py --cfg cfg/coco_s2.yml --gpu 1`
-- `*.yml` files are example configuration files for training/evaluating our models.
-- If you want to try your own datasets, [here](https://github.com/soumith/ganhacks) are some good tips about how to train GAN. Also, we encourage to try different hyper-parameters and architectures, especially for more complex datasets.
+**Check OS architecture**
+`cat /etc/os-release` returns the OS name, and `uname -m` returns the machine architecture. For us, it was `x86_64`.
+**Downloading Toolkit**
+[https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux](https://developer.nvidia.com/cuda-11-7-0-download-archive?target_os=Linux)
+We chose the network (online) installation:
+```commandline
+sudo dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
+sudo dnf clean all
+sudo dnf -y module install nvidia-driver:latest-dkms
+sudo dnf -y install cuda
+```
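+
+Once installed, the driver and toolkit can be sanity-checked with NVIDIA's standard tools:
+
+```commandline
+nvidia-smi     # should list the driver version and the visible GPUs
+nvcc --version # should report the toolkit release (may require /usr/local/cuda/bin on PATH)
+```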
-**Pretrained Model**
-- [StackGAN for coco](https://drive.google.com/open?id=0B3y_msrWZaXLYjNra2ZSSmtVQlE). Download and save it to `models/coco`.
-- **Our current implementation has a higher inception score(10.62±0.19) than reported in the StackGAN paper**
+**Data - Text**
+
+1. Download our preprocessed char-CNN-RNN text embeddings
+ for [training coco](https://drive.google.com/open?id=0B3y_msrWZaXLQXVzOENCY2E3TlU)
+   and [evaluating coco](https://drive.google.com/open?id=0B3y_msrWZaXLeEs5MTg0RC1fa0U), and save them to `data/coco`.
+
+2. [Optional] Follow the instructions at [reedscot/icml2016](https://github.com/reedscot/icml2016) to download the
+ pretrained char-CNN-RNN text encoders and extract text embeddings.
+
+**Data - Image**
+
+1. Download the [coco](http://cocodataset.org/#download) image data. Extract it to `data/coco/`.
+
+**Custom Dataset**
+
+1. See the `data/README.md` file.
+
+**Training COCO**
+
+- Steps to train a StackGAN model on the COCO dataset using our preprocessed embeddings:
+ - Step 1: train Stage-I GAN (e.g., for 120 epochs) `python code/main.py --cfg cfg/coco_s1.yml --gpu 0`
+ - Step 2: train Stage-II GAN (e.g., for another 120 epochs) `python code/main.py --cfg cfg/coco_s2.yml --gpu 1`
+- `*.yml` files are example configuration files for training/evaluating our models.
+- If you want to try your own datasets, [here](https://github.com/soumith/ganhacks) are some good tips on how to
+  train GANs. We also encourage you to try different hyper-parameters and architectures, especially for more complex
+  datasets.
+
+**Pretrained Model**
+
+- [StackGAN for coco](https://drive.google.com/open?id=0B3y_msrWZaXLYjNra2ZSSmtVQlE). Download and save it
+ to `models/coco`.
+- **Our current implementation achieves a higher inception score (10.62±0.19) than reported in the StackGAN paper.**
**Evaluating**
-- Run `python main.py --cfg cfg/coco_eval.yml --gpu 2` to generate samples from captions in COCO validation set.
+
+- Run `python code/main.py --cfg cfg/coco_eval.yml --gpu 2` to generate samples from captions in the COCO validation set.
Examples for COCO:
-
+


-Save your favorite pictures generated by our models since the randomness from noise z and conditioning augmentation makes them creative enough to generate objects with different poses and viewpoints from the same discription :smiley:
-
-
+Save your favorite pictures generated by our models: the randomness from the noise z and from conditioning augmentation
+makes them creative enough to generate objects with different poses and viewpoints from the same description :smiley:
### Citing StackGAN
+
If you find StackGAN useful in your research, please consider citing:
```
@@ -71,14 +116,14 @@ booktitle = {{ICCV}},
}
```
-
**Our follow-up work**
- [StackGAN++: Realistic Image Synthesis with Stacked Generative Adversarial Networks](https://arxiv.org/abs/1710.10916)
- [AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks](https://arxiv.org/abs/1711.10485) [[supplementary]](https://1drv.ms/b/s!Aj4exx_cRA4ghK5-kUG-EqH7hgknUA)[[code]](https://github.com/taoxugit/AttnGAN)
-
**References**
-- Generative Adversarial Text-to-Image Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016)
-- Learning Deep Representations of Fine-grained Visual Descriptions [Paper](https://arxiv.org/abs/1605.05395) [Code](https://github.com/reedscot/cvpr2016)
+- Generative Adversarial Text-to-Image
+ Synthesis [Paper](https://arxiv.org/abs/1605.05396) [Code](https://github.com/reedscot/icml2016)
+- Learning Deep Representations of Fine-grained Visual
+ Descriptions [Paper](https://arxiv.org/abs/1605.05395) [Code](https://github.com/reedscot/cvpr2016)
diff --git a/TODO.txt b/TODO.txt
new file mode 100644
index 0000000..251db6b
--- /dev/null
+++ b/TODO.txt
@@ -0,0 +1,7 @@
+May-25
+[+] Share image with - Tejas
+[*] Generate data for testing and test it
+[ ] Image superresolution with S1 images
+[ ] Run training with fasttext self-trained + latest captions
+[ ] VQGAN
+[ ] CLIP
diff --git a/code/__pycache__/model.cpython-38.pyc b/code/__pycache__/model.cpython-38.pyc
new file mode 100644
index 0000000..726f2db
Binary files /dev/null and b/code/__pycache__/model.cpython-38.pyc differ
diff --git a/code/__pycache__/trainer.cpython-38.pyc b/code/__pycache__/trainer.cpython-38.pyc
new file mode 100644
index 0000000..21e684e
Binary files /dev/null and b/code/__pycache__/trainer.cpython-38.pyc differ
diff --git a/code/cfg/sixray_500_s1.yml b/code/cfg/sixray_500_s1.yml
new file mode 100644
index 0000000..6c6daae
--- /dev/null
+++ b/code/cfg/sixray_500_s1.yml
@@ -0,0 +1,34 @@
+CONFIG_NAME: 'stage_1'
+
+DATASET_NAME: 'sixray_2381_charcnnrnn_1536_jemb'
+EMBEDDING_TYPE: 'embedding_bulk_word_1536_jemb.pickle'
+GPU_ID: '0,1'
+Z_DIM: 200
+DATA_DIR: 'data/sixray_2381'
+IMSIZE: 64
+WORKERS: 4
+STAGE: 1
+TRAIN:
+ FLAG: True
+ BATCH_SIZE: 64
+ BATCH_DROP_LAST: True
+ MAX_EPOCH: 1000
+ LR_DECAY_EPOCH: 20
+ SNAPSHOT_INTERVAL: 5
+ DISCRIMINATOR_LR: 0.0004
+ GENERATOR_LR: 0.0004
+ COEFF:
+ KL: 2.0
+ FINETUNE:
+ FLAG: False
+ EPOCH_START: 1001
+ NET_G: '/home/icmore_acc/Downloads/StackGAN-Pytorch/output/sixray_500_stage_1_2023_05_19_15_38_51/Model/netG_epoch_300.pth'
+ NET_D: '/home/icmore_acc/Downloads/StackGAN-Pytorch/output/sixray_500_stage_1_2023_05_19_15_38_51/Model/netD_epoch_last.pth'
+
+GAN:
+ CONDITION_DIM: 128
+ DF_DIM: 96
+ GF_DIM: 192
+
+TEXT:
+ DIMENSION: 1536
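+  # DIMENSION presumably must match the embedding width baked into EMBEDDING_TYPE
+  # above (a 1536-wide pickle here), since it sizes the conditioning input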
diff --git a/code/cfg/sixray_500_s2.yml b/code/cfg/sixray_500_s2.yml
new file mode 100644
index 0000000..6b3effd
--- /dev/null
+++ b/code/cfg/sixray_500_s2.yml
@@ -0,0 +1,36 @@
+CONFIG_NAME: 'stage_2'
+
+DATASET_NAME: 'sixray_2381_charcnnrnn_1536_jemb'
+EMBEDDING_TYPE: 'embedding_bulk_word_1536_jemb.pickle'
+GPU_ID: '0,1'
+Z_DIM: 200
+DATA_DIR: 'data/sixray_2381'
+IMSIZE: 256
+WORKERS: 4
+STAGE: 2
+STAGE1_G: ''
+TRAIN:
+ FLAG: True
+ BATCH_SIZE: 64
+ BATCH_DROP_LAST: True
+ MAX_EPOCH: 1000
+ LR_DECAY_EPOCH: 20
+ SNAPSHOT_INTERVAL: 5
+ DISCRIMINATOR_LR: 0.0008
+ GENERATOR_LR: 0.0008
+ COEFF:
+ KL: 1.0
+ FINETUNE:
+ FLAG: False
+ EPOCH_START: 1001
+ NET_G: 'output/experiment15/sixray_2381_ftt_1024D_cbow_nocrop_batch_double_stage_2_train_2023_06_12_18_16_28/Model/netG_epoch_1000.pth'
+ NET_D: 'output/experiment15/sixray_2381_ftt_1024D_cbow_nocrop_batch_double_stage_2_train_2023_06_12_18_16_28/Model/netD_epoch_last.pth'
+
+GAN:
+ CONDITION_DIM: 128
+ DF_DIM: 96
+ GF_DIM: 192
+ R_NUM: 2
+
+TEXT:
+ DIMENSION: 1536
diff --git a/code/cfg/sixray_s1.yml b/code/cfg/sixray_s1.yml
new file mode 100644
index 0000000..cf7ef28
--- /dev/null
+++ b/code/cfg/sixray_s1.yml
@@ -0,0 +1,34 @@
+CONFIG_NAME: 'stage1'
+
+DATASET_NAME: 'sixray_sample'
+EMBEDDING_TYPE: 'embeddings_cc.en.300.bin_300D.pickle'
+GPU_ID: '0,1'
+Z_DIM: 100
+DATA_DIR: '../data/sixray_sample'
+IMSIZE: 64
+WORKERS: 4
+STAGE: 1
+TRAIN:
+ FLAG: True
+ BATCH_SIZE: 6
+ BATCH_DROP_LAST: True
+ MAX_EPOCH: 300
+ LR_DECAY_EPOCH: 20
+ SNAPSHOT_INTERVAL: 10
+ DISCRIMINATOR_LR: 0.0002
+ GENERATOR_LR: 0.0002
+ COEFF:
+ KL: 2.0
+ FINETUNE:
+ FLAG: False
+ EPOCH_START: 0
+ NET_G: ''
+ NET_D: ''
+
+GAN:
+ CONDITION_DIM: 128
+ DF_DIM: 96
+ GF_DIM: 192
+
+TEXT:
+ DIMENSION: 300
diff --git a/code/cfg/sixray_s2.yml b/code/cfg/sixray_s2.yml
new file mode 100644
index 0000000..dde8b3e
--- /dev/null
+++ b/code/cfg/sixray_s2.yml
@@ -0,0 +1,36 @@
+CONFIG_NAME: 'stage2'
+
+DATASET_NAME: 'sixray_sample'
+EMBEDDING_TYPE: 'embeddings_cc.en.300.bin_300D.pickle'
+GPU_ID: '0,1'
+Z_DIM: 100
+STAGE1_G: 'output/sixray_sample_stage1_2023_05_12_19_17_04/Model/netG_epoch_300.pth'
+DATA_DIR: 'data/sixray_sample'
+WORKERS: 4
+IMSIZE: 256
+STAGE: 2
+TRAIN:
+ FLAG: True
+ BATCH_SIZE: 6
+ BATCH_DROP_LAST: True
+ MAX_EPOCH: 500
+ LR_DECAY_EPOCH: 20
+ SNAPSHOT_INTERVAL: 5
+ DISCRIMINATOR_LR: 0.0002
+ GENERATOR_LR: 0.0002
+ COEFF:
+ KL: 2.0
+ FINETUNE:
+ FLAG: False
+ EPOCH_START: 0
+ NET_G: ''
+ NET_D: ''
+
+GAN:
+ CONDITION_DIM: 128
+ DF_DIM: 96
+ GF_DIM: 192
+ R_NUM: 2
+
+TEXT:
+ DIMENSION: 300
diff --git a/code/main.py b/code/main.py
index 21fecc0..2ccb3e0 100644
--- a/code/main.py
+++ b/code/main.py
@@ -1,5 +1,10 @@
from __future__ import print_function
-import torch.backends.cudnn as cudnn
+
+import pathlib
+import shutil
+
+import PIL
+import numpy as np
import torch
import torchvision.transforms as transforms
@@ -9,16 +14,17 @@
import sys
import pprint
import datetime
-import dateutil
import dateutil.tz
-
+from PIL.Image import Image
+from git import Repo
+from torch.utils.data import DataLoader
+from torchvision.transforms.transforms import _setup_size  # note: a private torchvision helper
dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))
sys.path.append(dir_path)
from miscc.datasets import TextDataset
from miscc.config import cfg, cfg_from_file
-from miscc.utils import mkdir_p
from trainer import GANTrainer
@@ -26,13 +32,69 @@ def parse_args():
parser = argparse.ArgumentParser(description='Train a GAN network')
parser.add_argument('--cfg', dest='cfg_file',
help='optional config file',
- default='birds_stage1.yml', type=str)
- parser.add_argument('--gpu', dest='gpu_id', type=str, default='0')
+ default='cfg/coco_s1.yml', type=str)
+ parser.add_argument('--test_phase', dest='test_phase', default=False, action='store_true')
+ parser.add_argument('--NET_G', dest='NET_G', default='', help="Path to generator for testing")
+ parser.add_argument('--NET_D', dest='NET_D', default='', help="Path to discriminator for testing")
+ parser.add_argument('--gpu', dest='gpu_id', type=str, default='0')
parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
- parser.add_argument('--manualSeed', type=int, help='manual seed')
+ parser.add_argument('--STAGE1_G', dest='STAGE1_G', type=str, default='')
+ parser.add_argument('--manualSeed', type=int, help='manual seed', default=47)
args = parser.parse_args()
return args
+
+class AspectResize(torch.nn.Module):
+ """
+ Resize image while keeping the aspect ratio.
+ Extra parts will be covered with 255(white) color value
+    Extra area is filled with the background color value (255, i.e. white, by default)
+
+ def __init__(self, size, background=255):
+ super().__init__()
+ self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+ self.background = background
+
+ @staticmethod
+ def fit_image_to_canvas(image: Image, canvas_width, canvas_height, background=255) -> Image:
+ # Get the dimensions of the image
+ image_width, image_height = image.size
+
+ # Calculate the aspect ratio of the image
+ image_aspect_ratio = image_width / float(image_height)
+
+ # Calculate the aspect ratio of the canvas
+ canvas_aspect_ratio = canvas_width / float(canvas_height)
+
+ # Calculate the new dimensions of the image to fit the canvas
+ if canvas_aspect_ratio > image_aspect_ratio:
+ new_width = canvas_height * image_aspect_ratio
+ new_height = canvas_height
+ else:
+ new_width = canvas_width
+ new_height = canvas_width / image_aspect_ratio
+
+ # Resize the image to the new dimensions
+ image = image.resize((int(new_width), int(new_height)), PIL.Image.BICUBIC)
+
+ # Create a blank canvas of the specified size
+ canvas = np.zeros((int(canvas_height), int(canvas_width), 3), dtype=np.uint8)
+ canvas[:, :, :] = background
+
+ # Calculate the position to paste the resized image on the canvas
+ x = int((canvas_width - new_width) / 2)
+ y = int((canvas_height - new_height) / 2)
+
+ # Paste the resized image onto the canvas
+ canvas[y:y + int(new_height), x:x + int(new_width)] = np.array(image)
+
+ return PIL.Image.fromarray(canvas)
+
+ def forward(self, image: Image) -> Image:
+ image = self.fit_image_to_canvas(image, self.size[0], self.size[1], self.background)
+ return image
+
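+# Illustrative usage (hypothetical, not from the repo): letterbox a 300x200 photo
+# onto a white 64x64 canvas, preserving aspect ratio:
+#   resize = AspectResize(64)
+#   square = resize(PIL.Image.new("RGB", (300, 200)))  # -> 64x64 PIL image
+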
+
if __name__ == "__main__":
args = parse_args()
if args.cfg_file is not None:
@@ -41,37 +103,109 @@ def parse_args():
cfg.GPU_ID = args.gpu_id
if args.data_dir != '':
cfg.DATA_DIR = args.data_dir
+ if args.STAGE1_G != '':
+ cfg.STAGE1_G = args.STAGE1_G
+ if args.test_phase:
+ cfg.TRAIN.FLAG = False
+ if args.NET_G:
+ cfg.TRAIN.FINETUNE.FLAG = True
+ cfg.TRAIN.FINETUNE.NET_G = args.NET_G
+ cfg.TRAIN.FINETUNE.NET_D = args.NET_D
print('Using config:')
pprint.pprint(cfg)
+ pprint.pprint(args)
+ # save git checksum
+ project_root = pathlib.Path(__file__).parents[1]
+ repo = Repo(project_root)
+ args.git_checksum = repo.git.rev_parse("HEAD") # save commit checksum
+
if args.manualSeed is None:
args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
+ phase = "test" if args.test_phase else "train"
if cfg.CUDA:
torch.cuda.manual_seed_all(args.manualSeed)
now = datetime.datetime.now(dateutil.tz.tzlocal())
timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
- output_dir = '../output/%s_%s_%s' % \
- (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)
-
+ output_dir = 'output/%s_%s_%s_%s' % (cfg.DATASET_NAME, cfg.CONFIG_NAME, phase, timestamp)
+
+ if cfg.STAGE == 1:
+ # STAGE-1
+ if cfg.TRAIN.FLAG:
+ # STAGE-1 TRAINING
+ # prepare script for stage-2 training
+ with open("train_stage2.sh", "w") as fp:
+ fp.write(
+ "#!/usr/bin/bash\nsh code/miscc/cuda_mem.sh\n"
+ "python code/main.py --cfg {} --manualSeed 47 --STAGE1_G {}\n".format(
+ args.cfg_file.replace("s1", "s2"),
+ os.path.join(output_dir, "Model", "netG_epoch_{}.pth".format(cfg.TRAIN.MAX_EPOCH - 1))
+ ))
+ # prepare script for stage-1 testing
+ with open("test_stage1.sh", "w") as fp:
+ fp.write(
+ "#!/usr/bin/bash\nsh code/miscc/cuda_mem.sh\n"
+ "python code/main.py --test_phase --manualSeed 47 --cfg {} --NET_G {} --NET_D {}\n".format(
+ args.cfg_file,
+ os.path.join(output_dir, "Model", "netG_epoch_{}.pth".format(cfg.TRAIN.MAX_EPOCH)),
+ os.path.join(output_dir, "Model", "netD_epoch_last.pth"),
+ ))
+ else:
+ # STAGE-1 TESTING
+ ...
+ else:
+ # STAGE-2
+ if cfg.TRAIN.FLAG:
+ # STAGE-2 TRAINING
+ # prepare script for stage-2 testing
+ with open("test_stage2.sh", "w") as fp:
+ fp.write(
+ "#!/usr/bin/bash\nsh code/miscc/cuda_mem.sh\n"
+ "python code/main.py --test_phase --manualSeed 47 --cfg {} --NET_G {} --NET_D {}\n".format(
+ args.cfg_file,
+ os.path.join(output_dir, "Model", "netG_epoch_{}.pth".format(cfg.TRAIN.MAX_EPOCH)),
+ os.path.join(output_dir, "Model", "netD_epoch_last.pth"),
+ ))
+ else:
+ # STAGE-2 TESTING
+ ...
+
num_gpu = len(cfg.GPU_ID.split(','))
if cfg.TRAIN.FLAG:
+ # prepare image transforms
image_transform = transforms.Compose([
- transforms.RandomCrop(cfg.IMSIZE),
+        AspectResize(cfg.IMSIZE),  # pad-resize (replaces RandomCrop) so the whole object stays in frame
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
- dataset = TextDataset(cfg.DATA_DIR, 'train',
- imsize=cfg.IMSIZE,
- transform=image_transform)
- assert dataset
- dataloader = torch.utils.data.DataLoader(
- dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
- drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS))
-
+ # prepare Text caption
+ train_dataset = TextDataset(cfg.DATA_DIR, 'train',
+ imsize=cfg.IMSIZE,
+ embedding_type=cfg.EMBEDDING_TYPE,
+ transform=image_transform,
+ float_precision=32)
+ # prepare Text caption
+ test_dataset = TextDataset(cfg.DATA_DIR, 'test',
+ imsize=cfg.IMSIZE,
+ embedding_type=cfg.EMBEDDING_TYPE,
+ transform=image_transform,
+ float_precision=32)
+ print("Dataset Length:", len(train_dataset))
+ assert train_dataset
+ train_dataloader = DataLoader(dataset=train_dataset, batch_size=cfg.TRAIN.BATCH_SIZE * num_gpu,
+ drop_last=cfg.TRAIN.BATCH_DROP_LAST,
+ shuffle=True, num_workers=int(cfg.WORKERS))
+
algo = GANTrainer(output_dir)
- algo.train(dataloader, cfg.STAGE)
+ shutil.copyfile(args.cfg_file, os.path.join(output_dir, os.path.basename(args.cfg_file)))
+ with open(os.path.join(output_dir, "config.txt"), "w") as fp:
+ fp.write("%s\n" % (str(args)))
+ fp.write("%s" % (str(cfg)))
+ algo.train(train_dataloader, cfg.STAGE, test_dataset)
else:
- datapath= '%s/test/val_captions.t7' % (cfg.DATA_DIR)
- algo = GANTrainer(output_dir)
- algo.sample(datapath, cfg.STAGE)
+ datapath = os.path.join(cfg.DATA_DIR, "test", cfg.EMBEDDING_TYPE)
+ if os.path.isfile(datapath):
+            algo = GANTrainer(output_dir)
+            os.makedirs(output_dir, exist_ok=True)  # GANTrainer creates output dirs only in training mode
+            shutil.copyfile(args.cfg_file, os.path.join(output_dir, os.path.basename(args.cfg_file)))
+ algo.sample(datapath, output_dir, cfg.STAGE)
diff --git a/code/miscc/__pycache__/__init__.cpython-38.pyc b/code/miscc/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..c4861ed
Binary files /dev/null and b/code/miscc/__pycache__/__init__.cpython-38.pyc differ
diff --git a/code/miscc/__pycache__/config.cpython-38.pyc b/code/miscc/__pycache__/config.cpython-38.pyc
new file mode 100644
index 0000000..0693cb2
Binary files /dev/null and b/code/miscc/__pycache__/config.cpython-38.pyc differ
diff --git a/code/miscc/__pycache__/datasets.cpython-38.pyc b/code/miscc/__pycache__/datasets.cpython-38.pyc
new file mode 100644
index 0000000..58958b2
Binary files /dev/null and b/code/miscc/__pycache__/datasets.cpython-38.pyc differ
diff --git a/code/miscc/__pycache__/utils.cpython-38.pyc b/code/miscc/__pycache__/utils.cpython-38.pyc
new file mode 100644
index 0000000..c8ca6f5
Binary files /dev/null and b/code/miscc/__pycache__/utils.cpython-38.pyc differ
diff --git a/code/miscc/config.py b/code/miscc/config.py
index 666b30f..48ecbb5 100644
--- a/code/miscc/config.py
+++ b/code/miscc/config.py
@@ -1,11 +1,9 @@
from __future__ import division
from __future__ import print_function
-import os.path as osp
import numpy as np
from easydict import EasyDict as edict
-
__C = edict()
cfg = __C
@@ -27,11 +25,11 @@
__C.IMSIZE = 64
__C.STAGE = 1
-
# Training options
__C.TRAIN = edict()
__C.TRAIN.FLAG = True
__C.TRAIN.BATCH_SIZE = 64
+__C.TRAIN.BATCH_DROP_LAST = True
__C.TRAIN.MAX_EPOCH = 600
__C.TRAIN.SNAPSHOT_INTERVAL = 50
__C.TRAIN.PRETRAINED_MODEL = ''
@@ -43,6 +41,13 @@
__C.TRAIN.COEFF = edict()
__C.TRAIN.COEFF.KL = 2.0
+# To be used for resume training from a checkpoint
+__C.TRAIN.FINETUNE = edict()
+__C.TRAIN.FINETUNE.FLAG = False
+__C.TRAIN.FINETUNE.EPOCH_START = 0
+__C.TRAIN.FINETUNE.NET_G = ''
+__C.TRAIN.FINETUNE.NET_D = ''
+
# Modal options
__C.GAN = edict()
__C.GAN.CONDITION_DIM = 128
@@ -53,6 +58,27 @@
__C.TEXT = edict()
__C.TEXT.DIMENSION = 1024
+import sys
+
+
+def is_python_version(major, minor=None) -> bool:
+ """
+ Check for specific python major version and optionally minor version
+
+ Args:
+ major: int
+ minor: int [optional]
+
+ Return:
+        True if the major (and, when given, minor) version matches the running Python
+ """
+ assert isinstance(major, int)
+ if minor is None:
+ return sys.version_info[0] == major
+ else:
+ assert isinstance(minor, int)
+ return sys.version_info[0] == major and sys.version_info[1] == minor
+
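+# Illustrative: is_python_version(3) is True on any Python 3.x interpreter,
+# while is_python_version(3, 8) is True only on 3.8.
+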
def _merge_a_into_b(a, b):
"""Merge config dictionary a into config dictionary b, clobbering the
@@ -60,12 +86,18 @@ def _merge_a_into_b(a, b):
"""
if type(a) is not edict:
return
-
- for k, v in a.iteritems():
+ if is_python_version(2):
+ dict_iter = a.iteritems
+ elif is_python_version(3):
+ dict_iter = a.items
+ else:
+ return
+
+ for k, v in dict_iter():
# a must specify keys that are in b
- if not b.has_key(k):
+ if (is_python_version(2) and not b.has_key(k)) or (is_python_version(3) and k not in b):
raise KeyError('{} is not a valid config key'.format(k))
-
+
# the types must match, too
old_type = type(b[k])
if old_type is not type(v):
@@ -75,7 +107,7 @@ def _merge_a_into_b(a, b):
raise ValueError(('Type mismatch ({} vs. {}) '
'for config key: {}').format(type(b[k]),
type(v), k))
-
+
# recursively merge dicts
if type(v) is edict:
try:
@@ -91,6 +123,9 @@ def cfg_from_file(filename):
"""Load a config file and merge it into the default options."""
import yaml
with open(filename, 'r') as f:
- yaml_cfg = edict(yaml.load(f))
-
+ if is_python_version(2):
+ yaml_cfg = edict(yaml.load(f))
+ elif is_python_version(3):
+ yaml_cfg = edict(yaml.full_load(f))
+
_merge_a_into_b(yaml_cfg, __C)
diff --git a/code/miscc/cuda_mem.sh b/code/miscc/cuda_mem.sh
new file mode 100644
index 0000000..cc18208
--- /dev/null
+++ b/code/miscc/cuda_mem.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/bash
+#precision=4 # FP32- 4bytes
+#w=256
+#h=256
+#c=3
+#batch=64
+#embedding_dim=1024
+#additional=100 # 100MB
+#mb=$((($precision * (($w * $h * $c * $batch) + ($embedding_dim * $batch) + ($additional * 1048576))) / 1048576))
+# The above setting took around 450MB space
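+# Worked example: 4 * ((256*256*3*64) + (1024*64) + (100*1048576)) / 1048576
+#   = 470,024,192 bytes / 1048576 ≈ 448 MB, consistent with the ~450MB noted above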
+export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:512,garbage_collection_threshold:0.8"
diff --git a/code/miscc/datasets.py b/code/miscc/datasets.py
index 477fd20..2c9b99c 100644
--- a/code/miscc/datasets.py
+++ b/code/miscc/datasets.py
@@ -3,7 +3,7 @@
from __future__ import print_function
from __future__ import unicode_literals
-
+import torch
import torch.utils.data as data
from PIL import Image
import PIL
@@ -13,14 +13,19 @@
import random
import numpy as np
import pandas as pd
-
-from miscc.config import cfg
+from torchvision.transforms import transforms
class TextDataset(data.Dataset):
def __init__(self, data_dir, split='train', embedding_type='cnn-rnn',
- imsize=64, transform=None, target_transform=None):
-
+ imsize=64, transform=None, target_transform=None, float_precision=32):
+        assert float_precision in (32, 64), "Required 32 or 64 but {} was given".format(float_precision)
+        assert split in ('train', 'test'), "Required 'train' or 'test' but {} was given".format(split)
+ if float_precision == 32:
+ self.dtype = torch.float32
+ else:
+ self.dtype = torch.float64
+ self.float_precision = float_precision
self.transform = transform
self.target_transform = target_transform
self.imsize = imsize
@@ -31,13 +36,15 @@ def __init__(self, data_dir, split='train', embedding_type='cnn-rnn',
else:
self.bbox = None
split_dir = os.path.join(data_dir, split)
-
- self.filenames = self.load_filenames(split_dir)
+ self.split = split
+
self.embeddings = self.load_embedding(split_dir, embedding_type)
- self.class_id = self.load_class_id(split_dir, len(self.filenames))
# self.captions = self.load_all_captions()
-
- def get_img(self, img_path, bbox):
+ if split == "train":
+ self.filenames = self.load_filenames(split_dir)
+ self.class_id = self.load_class_id(split_dir, len(self.filenames))
+
+ def get_img(self, img_path, bbox) -> torch.Tensor:
img = Image.open(img_path).convert('RGB')
width, height = img.size
if bbox is not None:
@@ -49,12 +56,16 @@ def get_img(self, img_path, bbox):
x1 = np.maximum(0, center_x - R)
x2 = np.minimum(width, center_x + R)
img = img.crop([x1, y1, x2, y2])
- load_size = int(self.imsize * 76 / 64)
- img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
+ # load_size = int(self.imsize * 76 / 64)
+ # img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
if self.transform is not None:
img = self.transform(img)
- return img
-
+ else:
+ img = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])(img)
+ return img.type(self.dtype)
+
def load_bbox(self):
data_dir = self.data_dir
bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
@@ -70,15 +81,15 @@ def load_bbox(self):
#
filename_bbox = {img_file[:-4]: [] for img_file in filenames}
numImgs = len(filenames)
- for i in xrange(0, numImgs):
+ for i in range(0, numImgs):
# bbox = [x-left, y-top, width, height]
bbox = df_bounding_boxes.iloc[i][1:].tolist()
-
+
key = filenames[i][:-4]
filename_bbox[key] = bbox
#
return filename_bbox
-
+
def load_all_captions(self):
caption_dict = {}
for key in self.filenames:
@@ -86,7 +97,7 @@ def load_all_captions(self):
captions = self.load_captions(caption_name)
caption_dict[key] = captions
return caption_dict
-
+
def load_captions(self, caption_name):
cap_path = caption_name
with open(cap_path, "r") as f:
@@ -94,7 +105,7 @@ def load_captions(self, caption_name):
captions = [cap.replace("\ufffd\ufffd", " ")
for cap in captions if len(cap) > 0]
return captions
-
+
def load_embedding(self, data_dir, embedding_type):
if embedding_type == 'cnn-rnn':
embedding_filename = '/char-CNN-RNN-embeddings.pickle'
@@ -102,50 +113,63 @@ def load_embedding(self, data_dir, embedding_type):
embedding_filename = '/char-CNN-GRU-embeddings.pickle'
elif embedding_type == 'skip-thought':
embedding_filename = '/skip-thought-embeddings.pickle'
-
- with open(data_dir + embedding_filename, 'rb') as f:
- embeddings = pickle.load(f)
+ elif os.path.isfile(os.path.join(data_dir, embedding_type)):
+ # embeddings are provided as files
+ embedding_filename = embedding_type
+ else:
+            raise ValueError("No embedding file was found: '{}'".format(embedding_type))
+
+ # > https://github.com/reedscot/icml2016
+ # > https://github.com/reedscot/cvpr2016
+ # > https://arxiv.org/pdf/1605.05395.pdf
+
+ with open(os.path.join(data_dir, embedding_filename), 'rb') as f:
+ embeddings = pickle.load(f, encoding="bytes")
embeddings = np.array(embeddings)
# embedding_shape = [embeddings.shape[-1]]
- print('embeddings: ', embeddings.shape)
- return embeddings
-
+ print('embeddings: ', embeddings.shape, "original dtype:", embeddings.dtype)
+ return torch.tensor(embeddings, dtype=self.dtype)
+
def load_class_id(self, data_dir, total_num):
- if os.path.isfile(data_dir + '/class_info.pickle'):
- with open(data_dir + '/class_info.pickle', 'rb') as f:
- class_id = pickle.load(f)
+ path_ = os.path.join(data_dir, 'class_info.pickle')
+ if os.path.isfile(path_):
+ with open(path_, 'rb') as f:
+ class_id = np.array(pickle.load(f, encoding="bytes"))
else:
class_id = np.arange(total_num)
+ print('Class_ids: ', class_id.shape, "Sample:", class_id[0])
return class_id
-
+
def load_filenames(self, data_dir):
filepath = os.path.join(data_dir, 'filenames.pickle')
with open(filepath, 'rb') as f:
filenames = pickle.load(f)
- print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
+ print('Load filenames from: %s (%d)' % (filepath, len(filenames)), "sample:", filenames[0])
return filenames
-
+
def __getitem__(self, index):
- key = self.filenames[index]
- # cls_id = self.class_id[index]
- #
- if self.bbox is not None:
- bbox = self.bbox[key]
- data_dir = '%s/CUB_200_2011' % self.data_dir
- else:
- bbox = None
- data_dir = self.data_dir
-
- # captions = self.captions[key]
+ # captions = self.captions[filepath]
embeddings = self.embeddings[index, :, :]
- img_name = '%s/images/%s.jpg' % (data_dir, key)
- img = self.get_img(img_name, bbox)
-
- embedding_ix = random.randint(0, embeddings.shape[0]-1)
+ embedding_ix = random.randint(0, embeddings.shape[0] - 1)
embedding = embeddings[embedding_ix, :]
if self.target_transform is not None:
embedding = self.target_transform(embedding)
- return img, embedding
-
+ if self.split == "train":
+ filepath = self.filenames[index]
+ # cls_id = self.class_id[index]
+ if self.bbox is not None:
+ bbox = self.bbox[filepath]
+ data_dir = '%s/CUB_200_2011' % self.data_dir
+ else:
+ bbox = None
+ data_dir = self.data_dir
+
+ img_name = os.path.join(data_dir, filepath)
+ assert os.path.isfile(img_name), img_name
+ img = self.get_img(img_name, bbox)
+ return img, embedding
+ else:
+ return embedding
+
def __len__(self):
-        return len(self.filenames)
+        if self.split == "train":
+            return len(self.filenames)
+        # the test split loads embeddings only, so report its length instead
+        return len(self.embeddings)
diff --git a/code/miscc/utils.py b/code/miscc/utils.py
index e7ee288..e6978d3 100644
--- a/code/miscc/utils.py
+++ b/code/miscc/utils.py
@@ -33,9 +33,8 @@ def compute_discriminator_loss(netD, real_imgs, fake_imgs,
real_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
errD_real = criterion(real_logits, real_labels)
# wrong pairs
- inputs = (real_features[:(batch_size-1)], cond[1:])
- wrong_logits = \
- nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
+ inputs = (real_features[:(batch_size - 1)], cond[1:])
+ wrong_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
errD_wrong = criterion(wrong_logits, fake_labels[1:])
# fake pairs
inputs = (fake_features, cond)
@@ -43,12 +42,10 @@ def compute_discriminator_loss(netD, real_imgs, fake_imgs,
errD_fake = criterion(fake_logits, fake_labels)
if netD.get_uncond_logits is not None:
- real_logits = \
- nn.parallel.data_parallel(netD.get_uncond_logits,
- (real_features), gpus)
- fake_logits = \
- nn.parallel.data_parallel(netD.get_uncond_logits,
- (fake_features), gpus)
+ real_logits = nn.parallel.data_parallel(netD.get_uncond_logits,
+ (real_features), gpus)
+ fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits,
+ (fake_features), gpus)
uncond_errD_real = criterion(real_logits, real_labels)
uncond_errD_fake = criterion(fake_logits, fake_labels)
#
@@ -58,7 +55,7 @@ def compute_discriminator_loss(netD, real_imgs, fake_imgs,
errD_fake = (errD_fake + uncond_errD_fake) / 2.
else:
errD = errD_real + (errD_fake + errD_wrong) * 0.5
- return errD, errD_real.data[0], errD_wrong.data[0], errD_fake.data[0]
+ return errD, errD_real.data, errD_wrong.data, errD_fake.data
def compute_generator_loss(netD, fake_imgs, real_labels, conditions, gpus):
@@ -70,9 +67,8 @@ def compute_generator_loss(netD, fake_imgs, real_labels, conditions, gpus):
fake_logits = nn.parallel.data_parallel(netD.get_cond_logits, inputs, gpus)
errD_fake = criterion(fake_logits, real_labels)
if netD.get_uncond_logits is not None:
- fake_logits = \
- nn.parallel.data_parallel(netD.get_uncond_logits,
- (fake_features), gpus)
+ fake_logits = nn.parallel.data_parallel(netD.get_uncond_logits,
+ (fake_features), gpus)
uncond_errD_fake = criterion(fake_logits, real_labels)
errD_fake += uncond_errD_fake
return errD_fake
@@ -93,7 +89,7 @@ def weights_init(m):
#############################
-def save_img_results(data_img, fake, epoch, image_dir):
+def save_img_results(data_img, fake, epoch, image_dir, name_prefix='fake'):
num = cfg.VIS_COUNT
fake = fake[0:num]
# data_img is changed to [0,1]
@@ -104,12 +100,12 @@ def save_img_results(data_img, fake, epoch, image_dir):
normalize=True)
# fake.data is still [-1, 1]
vutils.save_image(
- fake.data, '%s/fake_samples_epoch_%03d.png' %
- (image_dir, epoch), normalize=True)
+ fake.data, '%s/%s_samples_epoch_%03d.png' %
+ (image_dir, name_prefix, epoch), normalize=True)
else:
vutils.save_image(
- fake.data, '%s/lr_fake_samples_epoch_%03d.png' %
- (image_dir, epoch), normalize=True)
+ fake.data, '%s/lr_%s_samples_epoch_%03d.png' %
+ (image_dir, name_prefix, epoch), normalize=True)
def save_model(netG, netD, epoch, model_dir):
@@ -118,7 +114,7 @@ def save_model(netG, netD, epoch, model_dir):
'%s/netG_epoch_%d.pth' % (model_dir, epoch))
torch.save(
netD.state_dict(),
- '%s/netD_epoch_last.pth' % (model_dir))
+ '%s/netD_epoch_last.pth' % model_dir)
print('Save G/D models')
diff --git a/code/trainer.py b/code/trainer.py
index a988206..29cdee1 100644
--- a/code/trainer.py
+++ b/code/trainer.py
@@ -1,4 +1,8 @@
from __future__ import print_function
+
+import pickle
+from pprint import pprint
+
from six.moves import range
from PIL import Image
@@ -9,9 +13,11 @@
import torch.optim as optim
import os
import time
+import gc
import numpy as np
import torchfile
+from torch.utils.tensorboard import SummaryWriter
from miscc.config import cfg
from miscc.utils import mkdir_p
@@ -21,11 +27,12 @@
from miscc.utils import compute_discriminator_loss, compute_generator_loss
from tensorboard import summary
-from tensorboard import FileWriter
+from tensorboardX import FileWriter
class GANTrainer(object):
def __init__(self, output_dir):
+ self.test_noise = None
if cfg.TRAIN.FLAG:
self.model_dir = os.path.join(output_dir, 'Model')
self.image_dir = os.path.join(output_dir, 'Image')
@@ -33,117 +40,127 @@ def __init__(self, output_dir):
mkdir_p(self.model_dir)
mkdir_p(self.image_dir)
mkdir_p(self.log_dir)
- self.summary_writer = FileWriter(self.log_dir)
-
+ print("Output:", output_dir)
+ self.summary_writer = SummaryWriter(self.log_dir)
+
self.max_epoch = cfg.TRAIN.MAX_EPOCH
self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
-
+
s_gpus = cfg.GPU_ID.split(',')
self.gpus = [int(ix) for ix in s_gpus]
self.num_gpus = len(self.gpus)
self.batch_size = cfg.TRAIN.BATCH_SIZE * self.num_gpus
torch.cuda.set_device(self.gpus[0])
cudnn.benchmark = True
-
+
# ############# For training stageI GAN #############
def load_network_stageI(self):
from model import STAGE1_G, STAGE1_D
netG = STAGE1_G()
netG.apply(weights_init)
- print(netG)
+ # print(netG)
netD = STAGE1_D()
netD.apply(weights_init)
- print(netD)
-
- if cfg.NET_G != '':
- state_dict = \
- torch.load(cfg.NET_G,
- map_location=lambda storage, loc: storage)
+ # print(netD)
+ if cfg.TRAIN.FINETUNE.FLAG:
+ assert os.path.isfile(
+ cfg.TRAIN.FINETUNE.NET_G), "TRAIN.FINETUNE.NET_G is required when TRAIN.FINETUNE.FLAG=True"
+ assert os.path.isfile(
+ cfg.TRAIN.FINETUNE.NET_D), "TRAIN.FINETUNE.NET_D is required when TRAIN.FINETUNE.FLAG=True"
+
+ state_dict = torch.load(cfg.TRAIN.FINETUNE.NET_G, map_location=lambda storage, loc: storage)
netG.load_state_dict(state_dict)
- print('Load from: ', cfg.NET_G)
- if cfg.NET_D != '':
- state_dict = \
- torch.load(cfg.NET_D,
- map_location=lambda storage, loc: storage)
+ print('Load from NET_G: ', cfg.TRAIN.FINETUNE.NET_G)
+
+ state_dict = torch.load(cfg.TRAIN.FINETUNE.NET_D, map_location=lambda storage, loc: storage)
netD.load_state_dict(state_dict)
- print('Load from: ', cfg.NET_D)
+ print('Load from NET_D: ', cfg.TRAIN.FINETUNE.NET_D)
if cfg.CUDA:
netG.cuda()
netD.cuda()
return netG, netD
-
+
# ############# For training stageII GAN #############
def load_network_stageII(self):
from model import STAGE1_G, STAGE2_G, STAGE2_D
-
+
Stage1_G = STAGE1_G()
netG = STAGE2_G(Stage1_G)
netG.apply(weights_init)
- print(netG)
- if cfg.NET_G != '':
- state_dict = \
- torch.load(cfg.NET_G,
- map_location=lambda storage, loc: storage)
- netG.load_state_dict(state_dict)
- print('Load from: ', cfg.NET_G)
- elif cfg.STAGE1_G != '':
- state_dict = \
- torch.load(cfg.STAGE1_G,
- map_location=lambda storage, loc: storage)
- netG.STAGE1_G.load_state_dict(state_dict)
- print('Load from: ', cfg.STAGE1_G)
- else:
- print("Please give the Stage1_G path")
- return
-
netD = STAGE2_D()
netD.apply(weights_init)
- if cfg.NET_D != '':
- state_dict = \
- torch.load(cfg.NET_D,
- map_location=lambda storage, loc: storage)
+ # print(netG)
+ # print(netD)
+ if cfg.TRAIN.FINETUNE.FLAG:
+ assert os.path.isfile(
+ cfg.TRAIN.FINETUNE.NET_G), "TRAIN.FINETUNE.NET_G is required when TRAIN.FINETUNE.FLAG=True"
+ assert os.path.isfile(
+ cfg.TRAIN.FINETUNE.NET_D), "TRAIN.FINETUNE.NET_D is required when TRAIN.FINETUNE.FLAG=True"
+
+ state_dict = torch.load(cfg.TRAIN.FINETUNE.NET_G, map_location=lambda storage, loc: storage)
+ netG.load_state_dict(state_dict)
+ print('Load from NET_G: ', cfg.TRAIN.FINETUNE.NET_G)
+
+ state_dict = torch.load(cfg.TRAIN.FINETUNE.NET_D, map_location=lambda storage, loc: storage)
netD.load_state_dict(state_dict)
- print('Load from: ', cfg.NET_D)
- print(netD)
-
+ print('Load from NET_D: ', cfg.TRAIN.FINETUNE.NET_D)
+
+ if cfg.STAGE1_G != '' and os.path.isfile(cfg.STAGE1_G):
+ state_dict = torch.load(cfg.STAGE1_G, map_location=lambda storage, loc: storage)
+ netG.STAGE1_G.load_state_dict(state_dict)
+ print('Load from STAGE1_G: ', cfg.STAGE1_G)
+ else:
+        raise ValueError("Please give the STAGE1_G path while training Stage-2 of StackGAN")
if cfg.CUDA:
netG.cuda()
netD.cuda()
return netG, netD
-
- def train(self, data_loader, stage=1):
+
+ def train(self, data_loader, stage=1, test_dataset=None):
if stage == 1:
netG, netD = self.load_network_stageI()
else:
netG, netD = self.load_network_stageII()
-
+
nz = cfg.Z_DIM
batch_size = self.batch_size
+ self.test_noise = None
noise = Variable(torch.FloatTensor(batch_size, nz))
- fixed_noise = \
- Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1),
- volatile=True)
+ with torch.no_grad():
+ fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
if cfg.CUDA:
noise, fixed_noise = noise.cuda(), fixed_noise.cuda()
real_labels, fake_labels = real_labels.cuda(), fake_labels.cuda()
-
+
generator_lr = cfg.TRAIN.GENERATOR_LR
discriminator_lr = cfg.TRAIN.DISCRIMINATOR_LR
lr_decay_step = cfg.TRAIN.LR_DECAY_EPOCH
- optimizerD = \
- optim.Adam(netD.parameters(),
- lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
+ optimizerD = optim.Adam(netD.parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999))
netG_para = []
for p in netG.parameters():
if p.requires_grad:
netG_para.append(p)
- optimizerG = optim.Adam(netG_para,
- lr=cfg.TRAIN.GENERATOR_LR,
- betas=(0.5, 0.999))
+ optimizerG = optim.Adam(netG_para, lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999))
+ # setup epoch
+ epoch_start = 0
+ if cfg.TRAIN.FINETUNE.FLAG:
+ epoch_start = cfg.TRAIN.FINETUNE.EPOCH_START
+ ######################################################
+    # Switch to double precision when the dataset dtype is float64
+ ######################################################
+ dtype = data_loader.dataset.dtype
+ if torch.float64 == dtype:
+ fixed_noise = fixed_noise.type(dtype)
+ noise = noise.type(dtype)
+ real_labels = real_labels.type(dtype)
+ fake_labels = fake_labels.type(dtype)
+ netG.double()
+ netD.double()
count = 0
- for epoch in range(self.max_epoch):
+ print("Training...")
+ for epoch in range(epoch_start, self.max_epoch):
start_t = time.time()
if epoch % lr_decay_step == 0 and epoch > 0:
generator_lr *= 0.5
@@ -152,8 +169,12 @@ def train(self, data_loader, stage=1):
discriminator_lr *= 0.5
for param_group in optimizerD.param_groups:
param_group['lr'] = discriminator_lr
-
- for i, data in enumerate(data_loader, 0):
+
+ loop_ran = False
+ for batch_idx, data in enumerate(data_loader, 0):
+ print("\rEpoch: {}/{} Batch: {}/{} ".format(
+                epoch + 1, self.max_epoch, batch_idx + 1, len(data_loader)), end="")
+ loop_ran = True
######################################################
# (1) Prepare training data
######################################################
@@ -163,26 +184,28 @@ def train(self, data_loader, stage=1):
if cfg.CUDA:
real_imgs = real_imgs.cuda()
txt_embedding = txt_embedding.cuda()
-
- #######################################################
+
+ ######################################################
# (2) Generate fake images
######################################################
noise.data.normal_(0, 1)
inputs = (txt_embedding, noise)
- _, fake_imgs, mu, logvar = \
- nn.parallel.data_parallel(netG, inputs, self.gpus)
-
- ############################
+
+ assert len(txt_embedding.shape) == len(
+ noise.shape) == 2, "Two 2D tensors are expected, Got {} & {}".format(
+ txt_embedding.shape, noise.shape)
+ _, fake_imgs, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)
+
+ ###########################
# (3) Update D network
###########################
netD.zero_grad()
- errD, errD_real, errD_wrong, errD_fake = \
- compute_discriminator_loss(netD, real_imgs, fake_imgs,
- real_labels, fake_labels,
- mu, self.gpus)
+ errD, errD_real, errD_wrong, errD_fake = compute_discriminator_loss(netD, real_imgs, fake_imgs,
+ real_labels, fake_labels,
+ mu, self.gpus)
errD.backward()
optimizerD.step()
- ############################
+ ###########################
# (2) Update G network
###########################
netG.zero_grad()
@@ -192,67 +215,88 @@ def train(self, data_loader, stage=1):
errG_total = errG + kl_loss * cfg.TRAIN.COEFF.KL
errG_total.backward()
optimizerG.step()
-
+
count = count + 1
- if i % 100 == 0:
- summary_D = summary.scalar('D_loss', errD.data[0])
- summary_D_r = summary.scalar('D_loss_real', errD_real)
- summary_D_w = summary.scalar('D_loss_wrong', errD_wrong)
- summary_D_f = summary.scalar('D_loss_fake', errD_fake)
- summary_G = summary.scalar('G_loss', errG.data[0])
- summary_KL = summary.scalar('KL_loss', kl_loss.data[0])
-
- self.summary_writer.add_summary(summary_D, count)
- self.summary_writer.add_summary(summary_D_r, count)
- self.summary_writer.add_summary(summary_D_w, count)
- self.summary_writer.add_summary(summary_D_f, count)
- self.summary_writer.add_summary(summary_G, count)
- self.summary_writer.add_summary(summary_KL, count)
-
+ if batch_idx % 100 == 0:
+ self.summary_writer.add_scalar('D_loss', errD.data, count)
+ self.summary_writer.add_scalar('D_loss_real', errD_real, count)
+ self.summary_writer.add_scalar('D_loss_wrong', errD_wrong, count)
+ self.summary_writer.add_scalar('D_loss_fake', errD_fake, count)
+ self.summary_writer.add_scalar('G_loss', errG.data, count)
+ self.summary_writer.add_scalar('KL_loss', kl_loss.data, count)
+ if (epoch % self.snapshot_interval == 0 or epoch == self.max_epoch - 1) and batch_idx % 100 == 0:
# save the image result for each epoch
inputs = (txt_embedding, fixed_noise)
- lr_fake, fake, _, _ = \
- nn.parallel.data_parallel(netG, inputs, self.gpus)
+ lr_fake, fake, _, _ = nn.parallel.data_parallel(netG, inputs, self.gpus)
save_img_results(real_img_cpu, fake, epoch, self.image_dir)
if lr_fake is not None:
save_img_results(None, lr_fake, epoch, self.image_dir)
+ ###########################
+ # GENERATE TEST IMAGES
+ ###########################
+ self.test(netG, test_dataset.embeddings, self.image_dir, epoch)
+
+ if loop_ran is False:
+            raise RuntimeError(
+ "Not enough data available.\n"
+ "Reasons:\n"
+ "(1) Dataset() length=0 or \n"
+ "(2) When `drop_last=True` in Dataloader() and the `Dataset() length` < `batch-size`\n"
+ "Solutions:\n"
+ "(1) Reduce batch size to satisfy `Dataset() length` >= `batch-size`[recommended]\n"
+ "(2) Set `drop_last=False`[not recommended]")
end_t = time.time()
- print('''[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
+ print('''\n[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_KL: %.4f
Loss_real: %.4f Loss_wrong:%.4f Loss_fake %.4f
Total Time: %.2fsec
'''
- % (epoch, self.max_epoch, i, len(data_loader),
- errD.data[0], errG.data[0], kl_loss.data[0],
+ % (epoch, self.max_epoch, batch_idx, len(data_loader),
+ errD.data.item(), errG.data.item(), kl_loss.data.item(),
errD_real, errD_wrong, errD_fake, (end_t - start_t)))
+
if epoch % self.snapshot_interval == 0:
save_model(netG, netD, epoch, self.model_dir)
+
+ # CLEAN GPU RAM ########################
+ del real_imgs
+ del txt_embedding
+ del inputs
+ del _
+ del fake_imgs
+ del mu
+ del logvar
+ del errD
+ del errD_real
+ del errD_wrong
+ del errD_fake
+ del kl_loss
+ del errG_total
+ # Fix: https://discuss.pytorch.org/t/how-to-totally-free-allocate-memory-in-cuda/79590
+ torch.cuda.empty_cache()
+ gc.collect()
+ print("memory_allocated(GB): ", torch.cuda.memory_allocated() / 1e9)
+ print("memory_cached(GB): ", torch.cuda.memory_reserved() / 1e9)
+ # CLEAN GPU RAM ########################
#
save_model(netG, netD, self.max_epoch, self.model_dir)
#
+ self.summary_writer.flush()
self.summary_writer.close()
-
- def sample(self, datapath, stage=1):
- if stage == 1:
- netG, _ = self.load_network_stageI()
- else:
- netG, _ = self.load_network_stageII()
+
+ def test(self, netG, embeddings, output_dir, epoch):
netG.eval()
-
- # Load text embeddings generated from the encoder
- t_file = torchfile.load(datapath)
- captions_list = t_file.raw_txt
- embeddings = np.concatenate(t_file.fea_txt, axis=0)
- num_embeddings = len(captions_list)
- print('Successfully load sentences from: ', datapath)
+ num_embeddings = len(embeddings)
print('Total number of sentences:', num_embeddings)
print('num_embeddings:', num_embeddings, embeddings.shape)
# path to save generated samples
- save_dir = cfg.NET_G[:cfg.NET_G.find('.pth')]
+ save_dir = output_dir + "/generated"
mkdir_p(save_dir)
-
batch_size = np.minimum(num_embeddings, self.batch_size)
nz = cfg.Z_DIM
- noise = Variable(torch.FloatTensor(batch_size, nz))
+ if self.test_noise is None:
+ self.test_noise = Variable(torch.FloatTensor(batch_size, nz))
+ self.test_noise.data.normal_(0, 1)
+ noise = self.test_noise
if cfg.CUDA:
noise = noise.cuda()
count = 0
@@ -268,23 +312,53 @@ def sample(self, datapath, stage=1):
txt_embedding = Variable(torch.FloatTensor(embeddings_batch))
if cfg.CUDA:
txt_embedding = txt_embedding.cuda()
-
+
#######################################################
# (2) Generate fake images
######################################################
- noise.data.normal_(0, 1)
inputs = (txt_embedding, noise)
- _, fake_imgs, mu, logvar = \
- nn.parallel.data_parallel(netG, inputs, self.gpus)
- for i in range(batch_size):
- save_name = '%s/%d.png' % (save_dir, count + i)
- im = fake_imgs[i].data.cpu().numpy()
- im = (im + 1.0) * 127.5
- im = im.astype(np.uint8)
- # print('im', im.shape)
- im = np.transpose(im, (1, 2, 0))
- # print('im', im.shape)
- im = Image.fromarray(im)
- im.save(save_name)
+ assert len(txt_embedding.shape) == len(noise.shape) == 2, "2D tensors are expected, Got {} & {}".format(
+ txt_embedding.shape, noise.shape)
+ _, fake_imgs, mu, logvar = nn.parallel.data_parallel(netG, inputs, self.gpus)
+ save_img_results(None, fake_imgs, epoch, save_dir, name_prefix="test")
+ # for i in range(batch_size):
+ # save_name = '%s/%d.png' % (save_dir, count + i)
+ # im = fake_imgs[i].data.cpu().numpy()
+ # im = (im + 1.0) * 127.5
+ # im = im.astype(np.uint8)
+ # # print('im', im.shape)
+ # im = np.transpose(im, (1, 2, 0))
+ # # print('im', im.shape)
+ # im = Image.fromarray(im)
+ # im.save(save_name)
count += batch_size
-
+ # CLEAN GPU RAM ########################
+ del txt_embedding
+ del inputs
+ del _
+ del fake_imgs
+ del mu
+ del logvar
+ del batch_size
+ del embeddings_batch
+ # Fix: https://discuss.pytorch.org/t/how-to-totally-free-allocate-memory-in-cuda/79590
+ torch.cuda.empty_cache()
+ gc.collect()
+ # CLEAN GPU RAM ########################
+ netG.train()
+
+ def sample(self, datapath, output_dir, stage):
+ if stage == 1:
+ netG, _ = self.load_network_stageI()
+ elif stage == 2:
+ netG, _ = self.load_network_stageII()
+ else:
+            raise ValueError("Stage must be 1 or 2, but {} was given".format(stage))
+
+ # Load text embeddings generated from the encoder
+ with open(datapath, 'rb') as f:
+ embeddings = pickle.load(f, encoding="bytes")
+ embeddings = np.array(embeddings)
+ # embedding_shape = [embeddings.shape[-1]]
+ print('test data embeddings: ', embeddings.shape)
+    print('Successfully loaded sentences from:', datapath)
diff --git a/create_data.sh b/create_data.sh
new file mode 100644
index 0000000..9a34ac1
--- /dev/null
+++ b/create_data.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/bash
+exit 0
+# generate data from a directory using a pretrained fasttext model
+python data/generate_custom_dataset.py --data_dir data/sixray_sample --fasttext_model /data/fasttext/cc.en.300.bin
+
+# generate data from SQLite using a pretrained fasttext model
+python data/generate_custom_dataset.py --data_dir data/sixray_500 --fasttext_model /data/fasttext/cc.en.300.bin --clean --copy_images --dataroot /data/Sixray_easy --sqlite /data/sixray_caption_db/
+# caption DB columns: id, file_id, caption, user, is_error, is_occluded