-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathsample_model.py
More file actions
executable file
·177 lines (129 loc) · 5.93 KB
/
sample_model.py
File metadata and controls
executable file
·177 lines (129 loc) · 5.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
# Hack so you don't have to put the library containing this script in the PYTHONPATH.
sys.path = [os.path.abspath(os.path.join(__file__, '..', '..'))] + sys.path
import numpy as np
from os.path import join as pjoin
import argparse
import theano
from iRBM.misc import utils
from iRBM.misc.utils import Timer
import pylab as plt
from iRBM.misc import vizu
def buildArgsParser():
    """Build the command-line parser for the sampling script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with a positional ``name`` (experiment name or path), the
        sampling options ``--nb-samples``, ``--cdk``, ``--full-gibbs-step``,
        ``--seed`` and the general options ``--view``, ``--save``, ``--out``.
    """
    DESCRIPTION = ("Script to sample from an RBM-like model.")
    p = argparse.ArgumentParser(description=DESCRIPTION)

    p.add_argument('name', type=str, help='name/path of the experiment.')

    # Sampling options
    sampling = p.add_argument_group("Sampling arguments")
    sampling.add_argument('--nb-samples', metavar='M', type=int,
                          help='number of samples. Default=16', default=16)
    sampling.add_argument('--cdk', metavar='K', type=int,
                          help='number of Gibbs steps. Default=10000', default=10000)
    sampling.add_argument('--full-gibbs-step', action='store_true',
                          help='if specified, use heuristic z=K for the first Gibbs step.')
    sampling.add_argument('--seed', type=int,
                          help='seed used to generate random numbers. Default=1234.', default=1234)

    # General options (optional)
    general = p.add_argument_group("General arguments")
    general.add_argument('--view', action='store_true',
                         help='display the samples.')
    general.add_argument('--save', action='store_true',
                         help='save the samples.')
    # Samples are written with np.savez, which appends ".npz" to any file
    # name lacking that extension. The former default "samples.npy" therefore
    # produced "samples.npy.npz" and contradicted the documented default.
    general.add_argument('--out', metavar='FILE', type=str,
                         help='file where samples will be saved. Default=samples.npz', default="samples.npz")

    return p
def main():
    """Sample from a trained RBM-like model, then view and/or save the samples.

    Workflow:
      1. Parse the command line; at least one of --view / --save is required.
      2. Resolve the experiment folder: ``args.name`` is used as a path if it
         is a directory, otherwise it is looked up under ./experiments.
         Both ``model.pkl`` and ``hyperparams.json`` must exist there.
      3. Pick the model class from ``hyperparams["model"]`` ("rbm", "orbm"
         or "irbm") and load the pickled model.
      4. Start ``--nb-samples`` chains from uniform random binary vectors and
         run ``--cdk`` compiled Gibbs steps. With --full-gibbs-step, one
         extra step through the plain RBM conditionals ("z=K") is done first.
      5. Save the final visible states with ``np.savez`` and/or show them
         as a grid of images.

    Raises
    ------
    ValueError
        If the model type or dataset named in the hyperparameters is unknown.
    """
    parser = buildArgsParser()
    args = parser.parse_args()

    # Check that at least one of --view or --save has been given.
    if not args.view and not args.save:
        parser.error("At least one of the following options must be chosen: --view or --save")

    # Get experiment folder
    experiment_path = args.name
    if not os.path.isdir(experiment_path):
        # If not a directory, it must be the name of the experiment.
        experiment_path = pjoin(".", "experiments", args.name)

    if not os.path.isdir(experiment_path):
        parser.error('Cannot find experiment: {0}!'.format(args.name))

    if not os.path.isfile(pjoin(experiment_path, "model.pkl")):
        parser.error('Cannot find model for experiment: {0}!'.format(experiment_path))

    if not os.path.isfile(pjoin(experiment_path, "hyperparams.json")):
        parser.error('Cannot find hyperparams for experiment: {0}!'.format(experiment_path))

    # Load experiments hyperparameters
    hyperparams = utils.load_dict_from_json_file(pjoin(experiment_path, "hyperparams.json"))

    with Timer("Loading model"):
        model_name = hyperparams["model"]
        if model_name == "rbm":
            from iRBM.models.rbm import RBM
            model_class = RBM
        elif model_name == "orbm":
            from iRBM.models.orbm import oRBM
            model_class = oRBM
        elif model_name == "irbm":
            from iRBM.models.irbm import iRBM
            model_class = iRBM
        else:
            # Without this branch an unknown name left `model_class` unbound
            # and crashed with an UnboundLocalError on the next line.
            raise ValueError("Unknown model: {0}".format(model_name))

        # Load the actual model.
        model = model_class.load(pjoin(experiment_path, "model.pkl"))

    rng = np.random.RandomState(args.seed)

    # Sample from uniform
    # TODO: sample from Bernouilli distribution parametrized with visible biases
    chain_start = (rng.rand(args.nb_samples, model.input_size) > 0.5).astype(theano.config.floatX)

    with Timer("Building sampling function"):
        # v0 holds the current visible state of every chain; each call to the
        # compiled `gibbs_step` overwrites it with the next Gibbs sample.
        v0 = theano.shared(np.asarray(chain_start, dtype=theano.config.floatX))
        v1 = model.gibbs_step(v0)
        gibbs_step = theano.function([], updates={v0: v1})

        if args.full_gibbs_step:
            print("Using z=K")
            # Use z=K for first Gibbs step: go through the plain RBM
            # conditionals (called unbound on `model`) for a single step.
            from iRBM.models.rbm import RBM
            h0 = RBM.sample_h_given_v(model, v0)
            v1 = RBM.sample_v_given_h(model, h0)
            v0.set_value(v1.eval())

    with Timer("Sampling"):
        for k in range(args.cdk):
            gibbs_step()

    samples = v0.get_value()

    if args.save:
        np.savez(args.out, samples)

    if args.view:
        # Image shape used to tile the samples, keyed by dataset name.
        image_shapes = {
            "binarized_mnist": (28, 28),
            "caltech101_silhouettes28": (28, 28),
        }
        dataset = hyperparams["dataset"]
        if dataset not in image_shapes:
            raise ValueError("Unknown dataset: {0}".format(dataset))
        image_shape = image_shapes[dataset]

        data = vizu.concatenate_images(samples, shape=image_shape, border_size=1, clim=(0, 1))
        plt.imshow(data, cmap=plt.cm.gray, interpolation='nearest')
        plt.show()
# def sample(chain_start=None, nb_samples=None, cdk=1, keep=1, all_active=(0, None), model=model):
# import numpy as np
# nb_samples = len(chain_start)
# model.batch_size = nb_samples
# print("Compiling function...")
# v0 = theano.shared(np.asarray(chain_start, dtype=theano.config.floatX))
# samples = []
# samples.append(v0.get_value())
# v1 = model.gibbs_step(v0)
# updates = {v0: v1}
# gibbs_step = theano.function([], updates=updates)
# # Sampling with all neurons active
# if hasattr(model, 'beta') and all_active[0] != 0:
# beta = model.beta
# model.beta = all_active[1]
# updates = {v0: model.gibbs_step(v0)}
# gibbs_full = theano.function([], updates=updates)
# for i in range(all_active[0]):
# gibbs_full()
# samples.append(v0.get_value())
# model.beta = beta
# print("Sampling...")
# for i in range(1, cdk+1):
# gibbs_step()
# if i % keep == 0:
# samples.append(v0.get_value())
# samples = np.array(samples)
# return samples
if __name__ == "__main__":
    # Run the sampling script only when executed directly, not when imported.
    main()