-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathguideddiffusion.py
More file actions
37 lines (27 loc) · 1.48 KB
/
guideddiffusion.py
File metadata and controls
37 lines (27 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# -*- coding: utf-8 -*-
"""Train a guided-diffusion classifier on the ISTD dataset inside Google Colab.

Converted from the Colab notebook GuidedDiffusion.ipynb:
https://colab.research.google.com/drive/1vg7xaBAv_Zv1WBLwai9VKlfC9nOg5bo1

Steps, in order:
  1. clone openai/guided-diffusion into /content/guided-diffusion;
  2. mount Google Drive at /content/gdrive;
  3. copy MyDrive/ISTD_Dataset.zip into the repo's data/ directory, unpack
     it, and move the contents of its ISTD_Dataset/ folder into data/;
  4. install the repo in editable mode, plus mpi4py;
  5. run scripts/classifier_train.py (under mpiexec) and report wall time.

Notebook shell escapes (``!cmd``) and IPython magics (``%cd``, ``%%time``)
are expressed with standard-library calls so this file is valid Python and
can be run with ``python guideddiffusion.py`` in a Colab runtime.
"""
import shlex
import shutil
import subprocess
import sys
import time
import zipfile
from pathlib import Path

REPO_URL = "https://github.com/openai/guided-diffusion"
REPO_DIR = Path("/content/guided-diffusion")
DATA_DIR = REPO_DIR / "data"
DRIVE_MOUNT_POINT = "/content/gdrive"
# Dataset archive as stored in the user's Google Drive.
DATASET_ZIP = Path(DRIVE_MOUNT_POINT) / "MyDrive" / "ISTD_Dataset.zip"
# Images the classifier is trained on (after the archive has been flattened into DATA_DIR).
TRAIN_DATA_DIR = DATA_DIR / "train" / "train_C"

# NOTE(review): 3 iterations looks like a smoke-test setting, and
# save_interval (10) exceeds iterations (3) -- confirm both before a real run.
TRAIN_FLAGS = (
    "--iterations 3 --anneal_lr True --batch_size 256 --lr 3e-4"
    " --save_interval 10 --weight_decay 0.05"
)
CLASSIFIER_FLAGS = (
    "--image_size 128 --classifier_attention_resolutions 32,16,8"
    " --classifier_depth 2 --classifier_width 128 --classifier_pool attention"
    " --classifier_resblock_updown True --classifier_use_scale_shift_norm True"
)


def run(cmd, cwd=None):
    """Echo and run *cmd* (an argv list); raise CalledProcessError on failure.

    Unlike a notebook ``!cmd``, a failing step stops the script instead of
    letting later steps run against a half-prepared checkout.
    """
    argv = [str(part) for part in cmd]
    print("$", " ".join(shlex.quote(part) for part in argv), flush=True)
    subprocess.run(argv, cwd=cwd, check=True)


def clone_repo(repo_dir=REPO_DIR, url=REPO_URL):
    """Clone *url* into *repo_dir*, skipping the clone if the directory exists."""
    repo_dir = Path(repo_dir)
    if repo_dir.exists():
        print(f"{repo_dir} already exists; skipping clone", flush=True)
        return
    run(["git", "clone", url, repo_dir])


def mount_drive(mount_point=DRIVE_MOUNT_POINT):
    """Mount Google Drive at *mount_point*."""
    # Imported lazily: google.colab only exists inside a Colab runtime.
    from google.colab import drive

    drive.mount(mount_point)


def move_contents(src, dst):
    """Move every entry directly inside *src* into *dst* (like ``mv src/* dst/``).

    *dst* is created if missing; *src* itself is left in place (now empty).
    """
    src, dst = Path(src), Path(dst)
    dst.mkdir(parents=True, exist_ok=True)
    for entry in sorted(src.iterdir()):
        shutil.move(str(entry), str(dst / entry.name))


def prepare_dataset(zip_path=DATASET_ZIP, data_dir=DATA_DIR, repo_dir=REPO_DIR):
    """Copy the ISTD archive into *data_dir*, unpack it, and flatten it there.

    The archive is unpacked into *repo_dir*. It is assumed to contain a
    top-level ``ISTD_Dataset/`` folder (TODO confirm against the actual zip).
    That folder's children are then moved directly under *data_dir*, giving
    e.g. ``data/train/train_C``. The copied zip stays in *data_dir*.
    """
    data_dir, repo_dir = Path(data_dir), Path(repo_dir)
    data_dir.mkdir(parents=True, exist_ok=True)
    local_zip = Path(shutil.copy(zip_path, data_dir))
    with zipfile.ZipFile(local_zip) as archive:
        archive.extractall(repo_dir)
    move_contents(repo_dir / "ISTD_Dataset", data_dir)


def install_requirements(repo_dir=REPO_DIR):
    """Install *repo_dir* (editable) and mpi4py into the interpreter running this script."""
    run([sys.executable, "-m", "pip", "install", "-e", "."], cwd=repo_dir)
    run([sys.executable, "-m", "pip", "install", "mpi4py"])


def build_train_command(data_dir=TRAIN_DATA_DIR, use_mpi=True):
    """Return the argv that runs ``scripts/classifier_train.py`` from the repo root.

    Args:
        data_dir: directory passed as ``--data_dir``.
        use_mpi: when true, launch through ``mpiexec --allow-run-as-root``;
            otherwise invoke the training script directly.

    Returns:
        A list of strings: interpreter, script path, ``--data_dir``, then
        TRAIN_FLAGS and CLASSIFIER_FLAGS split into separate arguments.
    """
    cmd = [sys.executable, "scripts/classifier_train.py", "--data_dir", str(data_dir)]
    cmd += shlex.split(TRAIN_FLAGS) + shlex.split(CLASSIFIER_FLAGS)
    if use_mpi:
        cmd = ["mpiexec", "--allow-run-as-root"] + cmd
    return cmd


def train_classifier(repo_dir=REPO_DIR, data_dir=TRAIN_DATA_DIR, use_mpi=True):
    """Run classifier training from *repo_dir* and print wall-clock time (like ``%%time``)."""
    start = time.perf_counter()
    run(build_train_command(data_dir, use_mpi=use_mpi), cwd=repo_dir)
    print(f"Wall time: {time.perf_counter() - start:.1f} s", flush=True)


def main():
    """Run every step in order: clone, mount, prepare data, install, train."""
    clone_repo()
    mount_drive()
    prepare_dataset()
    install_requirements()
    train_classifier()


if __name__ == "__main__":
    main()