-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlsg_test.py
More file actions
28 lines (24 loc) · 949 Bytes
/
lsg_test.py
File metadata and controls
28 lines (24 loc) · 949 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from networks.model.models import srm, lsg
import os
import torch
from utils import sample, draw, l_sample
# Sampling script: for each output index, draw a latent with the LSG model,
# feed it to the SRM model to obtain a stroke, and write the result to
# Samples/<experiment_name>/<index>.svg via draw().

# Ask torch to use 'medium' precision for float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')

# Sub-directory of Samples/ that receives this run's output files.
experiment_name = 'Your name here'

# Load both models from their checkpoint files.
SRM = srm.load_from_checkpoint("Models/First Run/SRM.ckpt")
LSG = lsg.load_from_checkpoint("Models/LSG-Train-run/epoch=322299-global_step=0.ckpt")

# Passed as the last argument of sample().
# NOTE(review): presumably the per-point input dimensionality -- confirm in utils.sample.
dim_in = 6
# Passed as the first argument of sample().
# NOTE(review): this is NOT the number of files written; see num_outputs below.
samples = 1000
# Passed to draw() together with SRM.format.
# NOTE(review): presumably the canvas size of the SVG -- confirm in utils.draw.
size = 512
# Number of sampling steps for the LSG
steps = list(range(50))
# Number of .svg files generated by the loop below (one per iteration).
num_outputs = 1000

# Build the output directory path once and create it if needed.
# exist_ok=True avoids the check-then-create race of os.path.exists() + makedirs().
output_dir = "Samples/{}".format(experiment_name)
os.makedirs(output_dir, exist_ok=True)

# Pure inference: run everything with gradient tracking disabled.
with torch.no_grad():
    for i in range(num_outputs):
        # LSG: sample a latent using the LSG model and its sampling scheduler.
        Latent = l_sample(steps, LSG.model, LSG.noise_scheduler_sample)
        # SRM: sample a stroke from the SRM decoder, given the latent above.
        stroke = sample(
            samples,
            SRM.sample_steps,
            SRM.decoder,
            SRM.noise_scheduler_sample,
            Latent,
            dim_in,
        )
        # Render: write the stroke to Samples/<experiment_name>/<i>.svg.
        filename = '{}/{}.svg'.format(output_dir, i)
        draw(SRM.format, size, filename, stroke)