-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
executable file
·1683 lines (1424 loc) · 58.9 KB
/
__init__.py
File metadata and controls
executable file
·1683 lines (1424 loc) · 58.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os.path
import requests
import traceback
import threading
import websocket # type: ignore
import json
import math
import re
import base64
import numpy as np # type: ignore
from PIL import Image # type: ignore
import base64 as b64 # type: ignore
import numbers
import six
from six import string_types
from six import BytesIO
from six.moves import urllib
import logging
import warnings
import time
import errno
import io
from functools import wraps
# Optional dependency: BeautifulSoup is only needed by `Visdom.matplot` when
# `opts['resizable']` is set, so record its availability instead of failing
# the whole import.
try:
    import bs4  # type: ignore
    BS4_AVAILABLE = True
except ImportError:
    BS4_AVAILABLE = False

# Directory containing this module; used to locate the VERSION file.
here = os.path.abspath(os.path.dirname(__file__))

# Package version read from the VERSION file next to this module; any failure
# (missing/unreadable file) falls back to a placeholder string.
try:
    with open(os.path.join(here, 'VERSION')) as version_file:
        __version__ = version_file.read().strip()
except Exception:
    __version__ = 'no_version_file'

# Prefer an installed `torchfile` package, otherwise use the copy shipped
# inside this package. BaseException is caught, so any failure of the first
# import (not only ImportError) triggers the fallback.
try:
    import torchfile  # type: ignore
except BaseException:
    from . import torchfile

# Python 2 compatibility shim. Raising ConnectionError probes for the builtin:
# on Python 2 the name lookup fails with NameError and a stand-in exception
# class is defined; on Python 3 the builtin is raised and simply caught.
try:
    raise ConnectionError()
except NameError:  # python 2 doesn't have ConnectionError
    class ConnectionError(Exception):
        pass
except ConnectionError:
    pass

# Silence the HTTP client libraries' loggers (presumably to hide per-request
# connection noise), then create this module's own logger.
logging.getLogger('requests').setLevel(logging.CRITICAL)
logging.getLogger('urllib3').setLevel(logging.CRITICAL)
logger = logging.getLogger(__name__)
def isstr(s):
    """Return True when `s` is one of six's `string_types` (text strings)."""
    is_text = isinstance(s, string_types)
    return is_text
def isnum(n):
    """Return True when `n` is registered as a `numbers.Number` (bool included)."""
    number_abc = numbers.Number
    return isinstance(n, number_abc)
def isndarray(n):
    """Return True when `n` is a numpy ndarray (numpy scalars do not count)."""
    array_type = np.ndarray
    return isinstance(n, array_type)
def nan2none(l):
    """Replace every NaN in the mutable sequence `l` with None, in place.

    Each element is tested with `math.isnan`, so all elements must be real
    numbers. The same (mutated) sequence is returned for convenience.
    """
    for position, value in enumerate(l):
        if not math.isnan(value):
            continue
        l[position] = None
    return l
def from_t7(t, b64=False):
    """Deserialize a Torch7 (.t7) serialized object held in memory.

    Args:
        t: the serialized bytes, or a base64-encoded payload when `b64`.
        b64: base64-decode `t` before parsing.

    Returns:
        The object produced by `torchfile.T7Reader(...).read_obj()`.
    """
    import tempfile
    if b64:
        t = base64.b64decode(t)
    # T7Reader is handed a real file, so the payload is staged in a private
    # temporary file. The old fixed path '/dev/shm/t7' was shared by all
    # concurrent callers and missing on non-Linux systems, and its read
    # handle was never closed. /dev/shm (tmpfs) is still preferred if present.
    tmp_dir = '/dev/shm' if os.path.isdir('/dev/shm') else None
    fd, path = tempfile.mkstemp(suffix='.t7', dir=tmp_dir)
    try:
        with os.fdopen(fd, 'wb') as ff:
            ff.write(t)
        with open(path, 'rb') as sf:
            return torchfile.T7Reader(sf).read_obj()
    finally:
        os.remove(path)
def loadfile(filename):
    """Return the entire contents of `filename` as bytes.

    Raises:
        AssertionError: if `filename` is not an existing regular file.
    """
    assert os.path.isfile(filename), 'could not find file %s' % filename
    # `with` closes the handle even if the read fails; the result no longer
    # shadows the builtin `str`. (`open` raises on failure, so the old
    # `assert fileobj` could never trigger.)
    with open(filename, 'rb') as fileobj:
        contents = fileobj.read()
    return contents
def _title2str(opts):
    """Convert a numeric, truthy `opts['title']` to a string in place.

    `_assert_opts` requires titles to be strings, so numeric titles are
    converted (with a warning) rather than rejected.

    Returns:
        The same `opts` dict.
    """
    title = opts.get('title')
    if title and isnum(title):
        title = str(title)
        # `logger.warn` is a deprecated alias of `logger.warning`.
        logger.warning('Numerical title %s has been casted to a string',
                       title)
        opts['title'] = title
    return opts
def _scrub_dict(d):
if type(d) is dict:
return {k: _scrub_dict(v) for k, v in list(d.items())
if v is not None and _scrub_dict(v) is not None}
else:
return d
def _axisformat(xy, opts):
fields = ['type', 'label', 'tickmin', 'tickmax', 'tickvals', 'ticklabels',
'tick', 'tickfont']
if any(opts.get(xy + i) for i in fields):
has_ticks = (opts.get(xy + 'tickmin') and opts.get(xy + 'tickmax')) \
is not None
return {
'type': opts.get(xy + 'type'),
'title': opts.get(xy + 'label'),
'range': [opts.get(xy + 'tickmin'),
opts.get(xy + 'tickmax')] if has_ticks else None,
'tickvals': opts.get(xy + 'tickvals'),
'ticktext': opts.get(xy + 'ticklabels'),
'dtick': opts.get(xy + 'tickstep'),
'showticklabels': opts.get(xy + 'tick'),
'tickfont': opts.get(xy + 'tickfont'),
}
def _opts2layout(opts, is3d=False):
    """Translate visdom `opts` into a Plotly layout dict.

    Args:
        opts: user options (title, legend, axis, margin, stacking, ...).
        is3d: also build a `zaxis` entry.

    Returns:
        The layout with None-valued entries removed recursively. Keys in
        `opts['layoutopts']['plotly']` shallowly override generated keys.
    """
    layout = {
        'showlegend': opts.get('showlegend', 'legend' in opts),
        'title': opts.get('title'),
        'xaxis': _axisformat('x', opts),
        'yaxis': _axisformat('y', opts),
        'margin': {
            'l': opts.get('marginleft', 60),
            'r': opts.get('marginright', 60),
            't': opts.get('margintop', 60),
            'b': opts.get('marginbottom', 60),
        }
    }
    if is3d:
        layout['zaxis'] = _axisformat('z', opts)
    # An explicit `stacked=False` now yields 'group'. Previously that branch
    # was unreachable because it sat behind `if opts.get('stacked')`.
    stacked = opts.get('stacked')
    if stacked is not None:
        layout['barmode'] = 'stack' if stacked else 'group'
    layout_opts = opts.get('layoutopts')
    if layout_opts is not None and 'plotly' in layout_opts:
        layout.update(layout_opts['plotly'])
    return _scrub_dict(layout)
def _markerColorCheck(mc, X, Y, L):
    """Validate marker colors and group the resulting color strings by label.

    Args:
        mc: ndarray of color values in [0, 255] with integral values. It is
            either per point (`X.shape[0]` entries) or per label (`L`
            entries). Each entry is a single intensity (1-D `mc`, rendered as
            blue with that alpha) or an RGB triple (`N x 3`).
        X: ndarray of points; `X.shape[0]` is the number of points.
        Y: per-point labels. For per-label colors, `Y[i] - 1` indexes `mc`,
            so labels are presumably 1..L.
        L: number of labels.

    Returns:
        dict mapping each label in `Y` to the list of its points' colors, in
        point order.
    """
    assert isndarray(mc), 'mc should be a numpy ndarray'
    # The message now reports the point count twice (`N` / `N x 3`). It used
    # to print X.shape[1], the point dimensionality, by mistake.
    assert mc.shape[0] == L or (
        mc.shape[0] == X.shape[0] and
        (mc.ndim == 1 or mc.ndim == 2 and mc.shape[1] == 3)), \
        ('marker colors have to be of size `%d` or `%d x 3` ' +
         ' or `%d` or `%d x 3`, but got: %s') % \
        (X.shape[0], X.shape[0], L, L, 'x'.join(map(str, mc.shape)))
    assert (mc >= 0).all(), 'marker colors have to be >= 0'
    assert (mc <= 255).all(), 'marker colors have to be <= 255'
    assert (mc == np.floor(mc)).all(), 'marker colors are assumed to be ints'

    mc = np.uint8(mc)
    if mc.ndim == 1:
        markercolor = ['rgba(0, 0, 255, %s)' % (v / 255.) for v in mc]
    else:
        markercolor = ['#%02x%02x%02x' % (i[0], i[1], i[2]) for i in mc]

    if mc.shape[0] != X.shape[0]:
        # Per-label colors: expand to one color per point through its label.
        markercolor = [markercolor[Y[i] - 1] for i in range(Y.shape[0])]

    # Append in place instead of rebuilding each list (was quadratic).
    ret = {}
    for k, color in enumerate(markercolor):
        ret.setdefault(Y[k], []).append(color)
    return ret
def _assert_opts(opts):
    """Validate the generic plotting options in `opts`.

    Each option is checked only when it is present with a non-None value.

    Raises:
        AssertionError: describing the first invalid option found.
    """
    def given(key):
        # Presence test rather than truthiness. Falsy but invalid values
        # (markersize=0, fps=0, jpgquality=0) used to skip validation.
        return opts.get(key) is not None

    if given('color'):
        assert isstr(opts.get('color')), 'color should be a string'
    if given('colormap'):
        assert isstr(opts.get('colormap')), \
            'colormap should be string'
    if given('mode'):
        assert isstr(opts.get('mode')), 'mode should be a string'
    if given('markersymbol'):
        assert isstr(opts.get('markersymbol')), \
            'marker symbol should be string'
    if given('markersize'):
        assert isnum(opts.get('markersize')) \
            and opts.get('markersize') > 0, \
            'marker size should be a positive number'
    if given('columnnames'):
        assert isinstance(opts.get('columnnames'), list), \
            'columnnames should be a table with column names'
    if given('rownames'):
        assert isinstance(opts.get('rownames'), list), \
            'rownames should be a table with row names'
    if given('jpgquality'):
        assert isnum(opts.get('jpgquality')), \
            'JPG quality should be a number'
        assert opts.get('jpgquality') > 0 and opts.get('jpgquality') <= 100, \
            'JPG quality should be number between 0 and 100'
    if given('opacity'):
        assert isnum(opts.get('opacity')), 'opacity should be a number'
        assert 0 <= opts.get('opacity') <= 1, \
            'opacity should be a number between 0 and 1'
    if given('fps'):
        assert isnum(opts.get('fps')), 'fps should be a number'
        assert opts.get('fps') > 0, 'fps must be greater than 0'
    if given('title'):
        assert isstr(opts.get('title')), 'title should be a string'
# Tensor classes that `_to_numpy` converts to numpy arrays. Stays empty when
# PyTorch is not installed (ImportError). It may be partially filled when an
# attribute lookup fails (AttributeError) after `torch.Tensor` was appended.
# An empty list makes `_to_numpy` skip all tensor handling.
torch_types = []
try:
    import torch
    torch_types.append(torch.Tensor)
    torch_types.append(torch.nn.Parameter)
except (ImportError, AttributeError):
    pass
def _to_numpy(a):
    """Convert lists and PyTorch tensors to numpy arrays; return anything else as-is.

    Lists become `np.array(a)`. When PyTorch was importable (`torch_types`
    is non-empty):

    - a legacy `torch.autograd.Variable` is unwrapped through `.data`, with a
      DeprecationWarning;
    - any instance of a class in `torch_types` is detached (when it has
      `detach`), moved to the CPU and returned as `.numpy()`.

    NOTE(review): on modern PyTorch, isinstance checks against `Variable` may
    succeed for every tensor, so the warning path may run for all tensors.
    Confirm before relying on it.
    """
    if isinstance(a, list):
        return np.array(a)
    if not torch_types:
        return a

    if isinstance(a, torch.autograd.Variable):
        # For PyTorch < 0.4 compatibility.
        warnings.warn(
            "Support for versions of PyTorch less than 0.4 is deprecated and "
            "will eventually be removed.", DeprecationWarning)
        a = a.data

    is_tensor = any(isinstance(a, kind) for kind in torch_types)
    if not is_tensor:
        return a
    # Pre-0.4 non-Variable tensors have no `detach` method.
    if hasattr(a, 'detach'):
        a = a.detach()
    return a.cpu().numpy()
def pytorch_wrap(f):
    """Decorator: pass every argument of `f` through `_to_numpy` first.

    Keyword arguments are converted before positional ones, and the wrapper
    keeps `f`'s metadata via `functools.wraps`.
    """
    @wraps(f)
    def wrapped_f(*args, **kwargs):
        converted_kwargs = dict(
            (name, _to_numpy(value)) for name, value in kwargs.items())
        converted_args = [_to_numpy(arg) for arg in args]
        return f(*converted_args, **converted_kwargs)
    return wrapped_f
class Visdom(object):
def __init__(
    self,
    server='http://localhost',
    endpoint='events',
    port=8097,
    ipv6=True,
    http_proxy_host=None,
    http_proxy_port=None,
    env='main',
    send=True,
    raise_exceptions=None,
    use_incoming_socket=True,
    log_to_filename=None,
):
    """Create a client for the Visdom server at `server`:`port`.

    Args:
        server: base URL including its scheme, e.g. 'http://localhost'.
        endpoint: stored as `self.endpoint`.
        port: server port.
        ipv6: stored as `self.ipv6`.
        http_proxy_host, http_proxy_port: proxy used by the websocket.
        env: default environment for requests that do not name one.
        send: when False nothing is sent; `_send` returns the message and no
            socket is opened or waited for.
        raise_exceptions: True raises ConnectionError on a failed request.
            False prints the traceback. None (default) prints it and also
            emits a PendingDeprecationWarning.
        use_incoming_socket: open a websocket to receive server events.
        log_to_filename: append sent plotting/update messages to this file.
    """
    # Host (and any path) without the scheme; used to build the socket URL.
    self.server_base_name = server[server.index("//") + 2:]
    self.server = server
    self.endpoint = endpoint
    self.port = port
    self.ipv6 = ipv6
    self.http_proxy_host = http_proxy_host
    self.http_proxy_port = http_proxy_port
    self.env = env  # default env
    self.send = send
    self.event_handlers = {}  # Haven't registered any events
    self.socket_alive = False
    self.use_socket = use_incoming_socket
    # Flag to indicate whether to raise errors or suppress them
    self.raise_exceptions = raise_exceptions
    self.log_to_filename = log_to_filename

    self._send({
        'eid': env,
    }, endpoint='env/' + env)

    if send and use_incoming_socket:
        # Get a backchannel from the server, then wait (0.1 s, 0.2 s, ...
        # doubling, ~5 s total) for its handshake. Waiting used to happen even
        # with send=False, where no socket is ever opened (~6 s stall plus a
        # spurious warning).
        self.setup_socket()
        time_spent = 0
        inc = 0.1
        while self.use_socket and not self.socket_alive and time_spent < 5:
            time.sleep(inc)
            time_spent += inc
            inc *= 2
        # Warn only when the wait actually timed out. The socket may have
        # connected during the last sleep, or been abandoned
        # (use_socket=False) after a refused connection.
        if self.use_socket and not self.socket_alive:
            logger.warning(
                'Visdom python client failed to establish socket to get '
                'messages from the server. This feature is optional and '
                'can be disabled by initializing Visdom with '
                '`use_incoming_socket=False`, which will prevent waiting for '
                'this request to timeout.'
            )
    elif send and not use_incoming_socket:
        logger.warning(
            'Without the incoming socket you cannot receive events from '
            'the server or register event handlers to your Visdom client.'
        )
def register_event_handler(self, handler, target):
    """Register `handler` to receive server messages addressed to `target`.

    Requires the incoming socket. Handlers of a target are kept (and called)
    in registration order.
    """
    assert callable(handler), 'Event handler must be a function'
    assert self.use_socket, \
        'Must be using the incoming socket to register events to web actions'
    self.event_handlers.setdefault(target, []).append(handler)
def clear_event_handlers(self, target):
    """Drop every handler registered for `target`, leaving an empty list."""
    self.event_handlers.update({target: []})
def setup_socket(self):
    """Start a daemon thread that keeps a websocket open to the server.

    Incoming messages:

    - an 'alive' command with data 'vis_alive' marks `self.socket_alive`;
    - any message with a 'target' is passed to the handlers registered for
      it via `register_event_handler`.

    `run_forever` is restarted (after a 3 s pause) for as long as
    `self.use_socket` is True. A refused connection clears that flag.
    """
    def on_message(ws, message):
        message = json.loads(message)
        if 'command' in message:
            # Handle server commands
            if message['command'] == 'alive':
                if 'data' in message and message['data'] == 'vis_alive':
                    logger.info('Visdom successfully connected to server')
                    self.socket_alive = True
                else:
                    logger.warning('Visdom server failed handshake, may not '
                                   'be properly connected')
        if 'target' in message:
            # Iterate over a copy: handlers may (un)register handlers.
            for handler in list(
                    self.event_handlers.get(message['target'], [])):
                handler(message)

    def on_error(ws, error):
        try:
            if error.errno == errno.ECONNREFUSED:
                logger.info(
                    "Socket refused connection, running socketless")
                ws.close()
                self.use_socket = False
            else:
                logger.error(error)
        except AttributeError:
            # Not an OS-level error (no `errno`); just report it.
            logger.error(error)

    def on_close(ws, *args):
        # websocket-client >= 1.0 calls on_close(ws, status_code, reason);
        # older releases call on_close(ws). Accept both signatures.
        self.socket_alive = False

    def run_socket(*args):
        host_scheme = urllib.parse.urlparse(self.server).scheme
        ws_scheme = "wss" if host_scheme == "https" else "ws"
        while self.use_socket:
            try:
                sock_addr = "{}://{}:{}/vis_socket".format(
                    ws_scheme, self.server_base_name, self.port)
                ws = websocket.WebSocketApp(
                    sock_addr,
                    on_message=on_message,
                    on_error=on_error,
                    on_close=on_close
                )
                ws.run_forever(http_proxy_host=self.http_proxy_host,
                               http_proxy_port=self.http_proxy_port)
                ws.close()
            except Exception as e:
                logger.error(
                    'Socket had error {}, attempting restart'.format(e))
            time.sleep(3)

    # Start listening thread
    self.socket_thread = threading.Thread(
        target=run_socket,
        name='Visdom-Socket-Thread'
    )
    self.socket_thread.daemon = True
    self.socket_thread.start()
# Utils
def _send(self, msg, endpoint='events', quiet=False, from_log=False):
"""
This function sends specified JSON request to the Tornado server. This
function should generally not be called by the user, unless you want to
build the required JSON yourself. `endpoint` specifies the destination
Tornado server endpoint for the request.
"""
if msg.get('eid', None) is None:
msg['eid'] = self.env
if not self.send:
return msg, endpoint
try:
r = requests.post(
"{0}:{1}/{2}".format(self.server, self.port, endpoint),
data=json.dumps(msg),
)
if self.log_to_filename is not None and not from_log:
if endpoint in ['events', 'update']:
if msg['win'] is None:
msg['win'] = r.text
with open(self.log_to_filename, 'a+') as log_file:
log_file.write(json.dumps([
endpoint,
msg,
]) + '\n')
return r.text
except BaseException:
if self.raise_exceptions:
raise ConnectionError("Error connecting to Visdom server")
else:
if self.raise_exceptions is None:
warnings.warn(
"Visdom is eventually changing to default to raising "
"exceptions rather than ignoring/printing. This change"
" is expected to happen by July 2018. Please set "
"`raise_exceptions` to False to retain current "
"behavior.",
PendingDeprecationWarning
)
if not quiet:
print("Exception in user code:")
print('-' * 60)
traceback.print_exc()
return False
def save(self, envs):
    """Ask the server to persist the environments whose ids are in `envs`.

    `envs` must be a list of strings. Returns the result of `_send` to the
    'save' endpoint.
    """
    assert isinstance(envs, list), 'envs should be a list'
    for env_id in envs:
        assert isstr(env_id), 'env should be a string'
    payload = {'data': envs}
    return self._send(payload, 'save')
def get_window_data(self, win=None, env=None):
    """Fetch the stored data of window `win` in environment `env`.

    `win=None` requests every window of the environment. `env=None` falls
    back to the client's default environment (filled in by `_send`).
    """
    payload = dict(win=win, eid=env)
    return self._send(msg=payload, endpoint='win_data')
def close(self, win=None, env=None):
    """Close window `win` in `env`; `win=None` closes every window of `env`."""
    payload = dict(win=win, eid=env)
    return self._send(msg=payload, endpoint='close')
def delete_env(self, env):
    """Ask the server to delete environment `env`.

    A None `env` is replaced by the client's default env inside `_send`.
    """
    payload = dict(eid=env)
    return self._send(msg=payload, endpoint='delete_env')
def _win_exists_wrap(self, win, env=None):
    """Raw existence query for window `win` in `env`.

    Returns the server's reply text ('true' or 'false'), or whatever `_send`
    returns on failure or when sending is disabled. Failures are not printed
    (`quiet=True`).
    """
    assert win is not None
    query = {'win': win, 'eid': env}
    return self._send(query, endpoint='win_exists', quiet=True)
def win_exists(self, win, env=None):
    """Return whether window `win` exists on the server.

    True or False for a 'true'/'false' reply. None when the reply is anything
    else or a ConnectionError is raised (a message is printed in that case).
    """
    try:
        reply = self._win_exists_wrap(win, env)
    except ConnectionError:
        print("Error connecting to Visdom server!")
        return None
    for answer, verdict in (('true', True), ('false', False)):
        if reply == answer:
            return verdict
    return None
def _win_hash_wrap(self, win, env=None):
    """Raw query for a hash of window `win`'s contents in environment `env`.

    Returns the server's reply text, presumably a 32-character hex digest
    when the window exists. On failure, or when sending is disabled, returns
    whatever `_send` returns.
    """
    assert win is not None
    # Send the environment under 'eid' like every other request in this
    # client. It used to go under 'env', so `_send` filled 'eid' with the
    # default env and the `env` argument was effectively ignored.
    return self._send({
        'win': win,
        'eid': env,
    }, endpoint='win_hash', quiet=True)
def win_hash(self, win, env=None):
    """
    This function returns md5 hash of the contents
    of a window if it exists on the server.
    Returns None, otherwise (also when the server cannot be reached).
    """
    try:
        e = self._win_hash_wrap(win, env)
    except ConnectionError:
        print("Error connecting to Visdom server!")
        return None
    # `_send` may return False (failed request) or a `(msg, endpoint)` tuple
    # (send=False). `re.match` raises TypeError on those, so only strings are
    # matched.
    if isstr(e) and re.match(r"([a-fA-F\d]{32})", e):
        return e
    return None
def check_connection(self):
    """Report whether the client is connected.

    Requires the server to answer a window-existence query. When the incoming
    socket is in use, its handshake must also have completed.
    """
    server_answers = self.win_exists('') is not None
    if not server_answers:
        return False
    return self.socket_alive or not self.use_socket
def replay_log(self, log_filename):
    """Re-send every `[endpoint, msg]` JSON line stored in `log_filename`.

    Used to restore server state from a log written via `log_to_filename`.
    Entries are sent with `from_log=True` so they are not logged again.
    """
    with open(log_filename) as log_file:
        entries = list(log_file)
    for line in entries:
        endpoint, msg = json.loads(line)
        self._send(msg, endpoint, from_log=True)
# Content
def text(self, text, win=None, env=None, opts=None, append=False):
    """Display the string `text` in a text pane.

    With `append=True` the message goes to the 'update' endpoint instead of
    'events' (presumably appending to the existing window `win`). Only the
    generic `opts` are validated; no text-specific options exist.
    """
    if opts is None:
        opts = {}
    _title2str(opts)
    _assert_opts(opts)
    endpoint = 'update' if append else 'events'
    payload = {
        'data': [{'content': text, 'type': 'text'}],
        'win': win,
        'eid': env,
        'opts': opts,
    }
    return self._send(payload, endpoint=endpoint)
def properties(self, data, win=None, env=None, opts=None):
    """
    Show a pane of editable properties.

    `data` is a list of dicts, one per property, e.g.:

    ```
    properties = [
        {'type': 'text', 'name': 'Text input', 'value': 'initial'},
        {'type': 'number', 'name': 'Number input', 'value': '12'},
        {'type': 'button', 'name': 'Button', 'value': 'Start'},
        {'type': 'checkbox', 'name': 'Checkbox', 'value': True},
        {'type': 'select', 'name': 'Select', 'value': 1,
         'values': ['Red', 'Green', 'Blue']},
    ]
    ```

    Supported types:
        - text: string
        - number: decimal number
        - button: button labeled with "value"
        - checkbox: boolean value rendered as a checkbox
        - select: multiple values select box
            - `value`: id of selected value (zero based)
            - `values`: list of possible values

    Callbacks are invoked on property value updates with:
        - `event_type`: `"PropertyUpdate"`
        - `propertyId`: position in the `properties` list
        - `value`: new value

    No specific `opts` are currently supported.
    """
    opts = opts if opts is not None else {}
    _assert_opts(opts)
    payload = {
        'data': [{'content': data, 'type': 'properties'}],
        'win': win,
        'eid': env,
        'opts': opts,
    }
    return self._send(payload, endpoint='events')
@pytorch_wrap
def svg(self, svgstr=None, svgfile=None, win=None, env=None, opts=None):
    """
    This function draws an SVG object. It takes as input an SVG string or
    the name of an SVG file (`svgfile` wins when both are given). The
    function does not support any plot-specific `opts`.

    Only the `<svg ...>...</svg>` span of the input is displayed.
    """
    opts = {} if opts is None else opts
    _title2str(opts)
    _assert_opts(opts)
    if svgfile is not None:
        contents = loadfile(svgfile)
        # `loadfile` returns bytes. On Python 3, str(bytes) yields the
        # "b'...'" repr with escaped newlines and quotes, which corrupted the
        # markup. Decode it instead.
        if isinstance(contents, bytes):
            svgstr = contents.decode('utf-8')
        else:
            svgstr = str(contents)
    assert svgstr is not None, 'should specify SVG string or filename'
    svg = re.search('<svg .+</svg>', svgstr, re.DOTALL)
    assert svg is not None, 'could not parse SVG string'
    return self.text(text=svg.group(0), win=win, env=env, opts=opts)
def matplot(self, plot, opts=None, env=None, win=None):
    """
    This function draws a Matplotlib `plot`. The function supports
    one plot-specific option: `resizable`. When set to `True` the plot
    is resized with the pane. You need `beautifulsoup4` and `lxml`
    packages installed to use this option.

    Unless given in `opts`, the pane size is derived from the SVG's size in
    points: height * 1.4 and width * 1.35 (each rounded up to whole points
    first).
    """
    opts = {} if opts is None else opts
    _title2str(opts)
    _assert_opts(opts)

    # write plot to SVG buffer:
    buffer = io.StringIO()
    plot.savefig(buffer, format='svg')
    buffer.seek(0)
    svg = buffer.read()
    buffer.close()

    height = None
    width = None
    if opts.get('resizable', False):
        if not BS4_AVAILABLE:
            raise ImportError("No module named 'bs4'")
        try:
            soup = bs4.BeautifulSoup(svg, 'xml')
        except bs4.FeatureNotFound as e:
            six.raise_from(ImportError("No module named 'lxml'"), e)
        # Remove the fixed size so the SVG scales with the pane. The popped
        # strings (e.g. '345.6pt') are still used to size the pane below.
        height = soup.svg.attrs.pop('height', None)
        width = soup.svg.attrs.pop('width', None)
        svg = str(soup)

    def size_in_points(attr, popped):
        # Size from a popped attribute string when available, otherwise from
        # the first `attr="<n>pt"` in the markup. None when not found. The
        # old code called `.group()` on the popped *string* and crashed
        # whenever `resizable` was set.
        if popped:
            match = re.search(r'([0-9.]*)pt', popped)
        else:
            match = re.search(attr + r'="([0-9.]*)pt"', svg)
        return None if match is None else float(match.group(1))

    # show SVG:
    if 'height' not in opts:
        height_pt = size_in_points('height', height)
        if height_pt is not None:
            opts['height'] = 1.4 * int(math.ceil(height_pt))
    if 'width' not in opts:
        width_pt = size_in_points('width', width)
        if width_pt is not None:
            opts['width'] = 1.35 * int(math.ceil(width_pt))
    return self.svg(svgstr=svg, opts=opts, env=env, win=win)
def plotlyplot(self, figure, win=None, env=None):
    """Render a Plotly `Figure` as-is (its own data and layout).

    No `opts` are taken; configure the figure's layout beforehand. The
    `plotly` package must be installed. Any ImportError raised in this
    method is reported as a RuntimeError.
    """
    try:
        import plotly
        encoder = plotly.utils.PlotlyJSONEncoder
        # Round-trip through JSON so Plotly's encoder turns numpy arrays and
        # other special values into plain Python structures.
        serialized = json.loads(json.dumps(figure, cls=encoder))
        payload = {
            'data': serialized['data'],
            'layout': serialized['layout'],
            'win': win,
            'eid': env,
        }
        return self._send(payload)
    except ImportError:
        raise RuntimeError(
            "Plotly must be installed to plot Plotly figures")
@pytorch_wrap
def image(self, img, win=None, env=None, opts=None):
    """
    This function draws an img. It takes as input an `CxHxW` or `HxW` tensor
    `img` that contains the image. The array values can be float in [0,1] or
    uint8 in [0, 255].

    Processing steps:

    - float images whose maximum is <= 1 are scaled by 255;
    - values are clipped to [0, 255] and cast to uint8, so other numeric
      dtypes (e.g. int64) are accepted as well;
    - single-channel images are replicated to three channels.

    `opts['width']`/`opts['height']` default to the image size.
    `opts['caption']` is sent as the image caption.
    """
    opts = {} if opts is None else opts
    _title2str(opts)
    _assert_opts(opts)
    opts['width'] = opts.get('width', img.shape[img.ndim - 1])
    opts['height'] = opts.get('height', img.shape[img.ndim - 2])

    nchannels = img.shape[0] if img.ndim == 3 else 1
    if nchannels == 1:
        squeezed = np.squeeze(img)
        if squeezed.ndim != 2:
            # np.squeeze also removed a height or width of 1. Keep the
            # trailing H x W axes instead, so the indexing below works.
            squeezed = img.reshape(img.shape[-2:])
        img = squeezed[np.newaxis, :, :].repeat(3, axis=0)

    if 'float' in str(img.dtype):
        if img.max() <= 1:
            img = img * 255.
    if img.dtype != np.uint8:
        # Saturate instead of letting out-of-range values wrap around in the
        # uint8 cast; also converts integer dtypes PIL cannot handle.
        img = np.uint8(np.clip(img, 0, 255))

    img = np.transpose(img, (1, 2, 0))
    im = Image.fromarray(img)
    buf = BytesIO()
    im.save(buf, format='PNG')
    b64encoded = b64.b64encode(buf.getvalue()).decode('utf-8')

    data = [{
        'content': {
            'src': 'data:image/png;base64,' + b64encoded,
            'caption': opts.get('caption'),
        },
        'type': 'image',
    }]
    return self._send({
        'data': data,
        'win': win,
        'eid': env,
        'opts': opts,
    })
@pytorch_wrap
def images(self, tensor, nrow=8, padding=2,
           win=None, env=None, opts=None):
    """
    Given a 4D tensor of shape (B x C x H x W),
    or a list of images all of the same size,
    makes a grid of images of size (B / nrow, nrow).
    This is a modified from `make_grid()`
    https://github.com/pytorch/vision/blob/master/torchvision/utils.py

    Each image occupies a cell of (H + 2*padding) x (W + 2*padding) on a
    background of ones. A single image (2D or 3D input) is drawn directly
    with `image`. Single-channel inputs are replicated to three channels.
    """
    # If list of images, convert to a 4D tensor
    if isinstance(tensor, list):
        tensor = np.stack(tensor, 0)

    if tensor.ndim == 2:  # single image H x W
        tensor = np.expand_dims(tensor, 0)
    if tensor.ndim == 3:  # single image
        if tensor.shape[0] == 1:  # if single-channel, convert to 3-channel
            tensor = np.repeat(tensor, 3, 0)
        return self.image(tensor, win, env, opts)
    if tensor.ndim == 4 and tensor.shape[1] == 1:  # single-channel images
        tensor = np.repeat(tensor, 3, 1)

    # make 4D tensor of images into a grid
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height = int(tensor.shape[2] + 2 * padding)
    width = int(tensor.shape[3] + 2 * padding)

    grid = np.ones([3, height * ymaps, width * xmaps])
    for k in range(nmaps):
        y, x = divmod(k, xmaps)  # row-major placement
        # Each image starts `padding` pixels into its cell. The old extra
        # `+ 1` shifted every tile by one pixel and overflowed the grid
        # (broadcast error) when padding=0.
        h_start = y * height + padding
        w_start = x * width + padding
        grid[:, h_start:h_start + tensor.shape[2],
             w_start:w_start + tensor.shape[3]] = tensor[k]
    return self.image(grid, win, env, opts)
@pytorch_wrap
def audio(self, tensor=None, audiofile=None, win=None, env=None, opts=None):
    """
    This function plays audio. It takes as input the filename of the audio
    file or an `N` tensor containing the waveform (use an `Nx2` matrix for
    stereo audio). The function does not support any plot-specific `opts`.

    The following `opts` are supported:

    - `opts.sample_frequency`: sample frequency (`integer` > 0; default = 44100)

    A waveform tensor is peak-normalized to 16-bit PCM, written to a
    temporary WAV file and deleted once its bytes have been read. Files must
    end in .wav, .mp3, .ogg or .flac. The pane is fixed at 330 x 80
    (width x height).
    """
    opts = {} if opts is None else opts
    opts['sample_frequency'] = opts.get('sample_frequency', 44100)
    _title2str(opts)
    _assert_opts(opts)
    assert tensor is not None or audiofile is not None, \
        'should specify audio tensor or file'
    if tensor is not None:
        assert tensor.ndim == 1 or (tensor.ndim == 2 and tensor.shape[1] == 2), \
            'tensor should be 1D vector or 2D matrix with 2 columns'

    tmpfile = None  # WAV file created here; removed in `finally`
    try:
        if tensor is not None:
            import scipy.io.wavfile  # type: ignore
            import tempfile
            # mkstemp yields a unique, portable path. The old code relied on
            # the private tempfile._get_candidate_names() and a hard-coded
            # /tmp, and never deleted the file.
            fd, tmpfile = tempfile.mkstemp(suffix='.wav')
            os.close(fd)
            audiofile = tmpfile
            peak = np.max(np.abs(tensor))
            if peak > 0:
                tensor = np.int16(tensor / peak * 32767)
            else:
                # Silent input: avoid 0/0 (NaN samples) and emit zeros.
                tensor = np.zeros(tensor.shape, dtype=np.int16)
            scipy.io.wavfile.write(
                audiofile, opts.get('sample_frequency'), tensor)

        extension = audiofile.split('.')[-1].lower()
        audio_types = {'wav': 'wav', 'mp3': 'mp3', 'ogg': 'ogg', 'flac': 'flac'}
        mimetype = audio_types.get(extension)
        assert mimetype is not None, 'unknown audio type: %s' % extension
        bytestr = loadfile(audiofile)
    finally:
        if tmpfile is not None:
            os.remove(tmpfile)

    videodata = """
<audio controls>
<source type="audio/%s" src="data:audio/%s;base64,%s">
Your browser does not support the audio tag.
</audio>
""" % (mimetype, mimetype, base64.b64encode(bytestr).decode('utf-8'))
    opts['height'] = 80
    opts['width'] = 330
    return self.text(text=videodata, win=win, env=env, opts=opts)
@pytorch_wrap
def video(self, tensor=None, videofile=None, win=None, env=None, opts=None):
    """
    This function plays a video. It takes as input the filename of the video
    `videofile` or a `LxHxWxC`-sized `tensor` containing all the frames of
    the video as input. The function does not support any plot-specific `opts`.

    The following `opts` are supported:

    - `opts.fps`: FPS for the video (`integer` > 0; default = 25)

    A tensor is encoded with OpenCV (Theora, .ogv) into a temporary file that
    is deleted once its bytes have been read. Frames are presumably expected
    in OpenCV's uint8 BGR layout; confirm for your input. Files must end in
    .mp4, .ogv, .avi or .webm.
    """
    opts = {} if opts is None else opts
    opts['fps'] = opts.get('fps', 25)
    _title2str(opts)
    _assert_opts(opts)
    assert tensor is not None or videofile is not None, \
        'should specify video tensor or file'

    tmpfile = None  # video file created here; removed in `finally`
    try:
        if tensor is not None:
            import cv2  # type: ignore
            import tempfile
            assert tensor.ndim == 4, 'video should be in 4D tensor'
            # mkstemp yields a unique, portable path. The old code used the
            # private tempfile._get_candidate_names() and a hard-coded /tmp,
            # and leaked the file.
            fd, tmpfile = tempfile.mkstemp(suffix='.ogv')
            os.close(fd)
            videofile = tmpfile
            # chr(ord(c)) keeps the native `str` type on Python 2, where
            # literals are unicode (unicode_literals).
            theora = [chr(ord(c)) for c in 'THEO']
            if cv2.__version__.startswith('2'):  # OpenCV 2
                fourcc = cv2.cv.CV_FOURCC(*theora)
            else:
                # OpenCV 3 and later. The old `startswith('3')` check left
                # `fourcc` undefined (NameError) on OpenCV 4.
                fourcc = cv2.VideoWriter_fourcc(*theora)
            writer = cv2.VideoWriter(
                videofile,
                fourcc,
                opts.get('fps'),
                (tensor.shape[2], tensor.shape[1])
            )
            assert writer.isOpened(), 'video writer could not be opened'
            for frame in tensor:
                # TODO mute opencv on this function call somehow
                writer.write(frame)
            writer.release()

        extension = videofile.split(".")[-1].lower()
        video_types = {'mp4': 'mp4', 'ogv': 'ogg', 'avi': 'avi', 'webm': 'webm'}
        mimetype = video_types.get(extension)
        assert mimetype is not None, 'unknown video type: %s' % extension
        bytestr = loadfile(videofile)
    finally:
        if tmpfile is not None:
            os.remove(tmpfile)

    videodata = """
<video controls>
<source type="video/%s" src="data:video/%s;base64,%s">
Your browser does not support the video tag.
</video>
""" % (mimetype, mimetype, base64.b64encode(bytestr).decode('utf-8'))
    return self.text(text=videodata, win=win, env=env, opts=opts)
def update_window_opts(self, win, opts, env=None):
    """Push new `opts` to existing window `win` without changing its content.

    The Plotly layout derived from `opts` (via `_opts2layout`) is sent along
    with the raw options to the 'update' endpoint.
    """
    layout = _opts2layout(opts)
    return self._send(
        {'win': win, 'eid': env, 'layout': layout, 'opts': opts},
        endpoint='update',
    )
@pytorch_wrap
def scatter(self, X, Y=None, win=None, env=None, opts=None, update=None,
name=None):
"""
This function draws a 2D or 3D scatter plot. It takes in an `Nx2` or
`Nx3` tensor `X` that specifies the locations of the `N` points in the
scatter plot. An optional `N` tensor `Y` containing discrete labels that
range between `1` and `K` can be specified as well -- the labels will be
reflected in the colors of the markers.
`update` can be used to efficiently update the data of an existing plot.
Use 'append' to append data, 'replace' to use new data, and 'remove' to
delete the trace that is specified in `name`. If updating a single
trace, use `name` to specify the name of the trace to be updated.
Update data that is all NaN is ignored (can be used for masking update).
Using `update='append'` will create a plot if it doesn't exist
and append to the existing plot otherwise.
The following `opts` are supported:
- `opts.markersymbol`: marker symbol (`string`; default = `'dot'`)
- `opts.markersize` : marker size (`number`; default = `'10'`)
- `opts.markercolor` : marker color (`np.array`; default = `None`)
- `opts.textlabels` : text label for each point (`list`: default = `None`)
- `opts.legend` : `table` containing legend names
"""
if update == 'remove':
assert win is not None
assert name is not None, 'A trace must be specified for deletion'
assert opts is None, 'Opts cannot be updated on trace deletion'
data_to_send = {
'data': [],