-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsplitter.py
More file actions
36 lines (25 loc) · 888 Bytes
/
splitter.py
File metadata and controls
36 lines (25 loc) · 888 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
def split(tensor, indices, threshold):
    """
    Split ``tensor`` via SVD into a product of the form A_{ijk} B_{k,...}.

    The two axes named by ``indices`` become the leading axes i, j of A.
    All remaining axes, in their original order, become the trailing axes
    of B. k is the new internal index.

    Truncation: a singular value s_k is kept only if its share of the total
    squared weight, ``s_k**2 / sum(s**2)``, is strictly greater than
    ``threshold``. This is a per-value cut, not a bound on the cumulative
    discarded error. A tensor whose singular values are all zero keeps none.

    Parameters
    ----------
    tensor : array_like
        Input tensor with at least two axes.
    indices : sequence of two ints
        The axes i and j. Negative values count from the last axis.
    threshold : float
        Minimum relative squared weight a singular value needs to be kept.

    Returns
    -------
    u : ndarray, shape (tensor.shape[i], tensor.shape[j], k)
        The tensor A (left singular vectors).
    s : ndarray, shape (k,)
        The kept singular values, in the order returned by the SVD.
    v : ndarray, shape (k, *remaining axes)
        Right singular vectors, without the singular values absorbed.
    vs : ndarray, shape (k, *remaining axes)
        ``v`` with the singular values absorbed (B = diag(s) @ v). Contracting
        ``u`` with ``vs`` over k approximates the axis-permuted tensor.

    Raises
    ------
    ValueError
        If ``indices`` does not name exactly two distinct, valid axes.
    """
    tensor = np.asarray(tensor)
    ndim = tensor.ndim
    indices = list(indices)
    if len(indices) != 2:
        raise ValueError('indices must name exactly two axes')
    # Normalize negative axes so they compare equal to entries of range(ndim).
    indices = [i + ndim if i < 0 else i for i in indices]
    rest = list(range(ndim))
    for axis in indices:
        if axis not in rest:  # out of range, or the same axis given twice
            raise ValueError('invalid or repeated axis in indices')
        rest.remove(axis)

    # Bring (i, j) to the front and flatten to a (i*j) x (rest) matrix.
    t = np.transpose(tensor, axes=indices + rest)
    rows = t.shape[0] * t.shape[1]
    # np.prod(()) is the float 1.0, which reshape rejects; force an int.
    cols = int(np.prod(t.shape[2:], dtype=np.int64))
    arr = np.reshape(t, (rows, cols))
    u, s, v = np.linalg.svd(arr, full_matrices=False)

    # Relative squared weight of each singular value. Guard the all-zero
    # case explicitly instead of producing NaNs from 0/0.
    total = np.sum(s**2)
    if total > 0:
        keep = s**2 / total > threshold
    else:
        keep = np.zeros(s.shape, dtype=bool)
    s = s[keep]
    u = u[:, keep]
    v = v[keep]
    k = s.shape[0]

    # Absorb the singular values into the right factor: diag(s) @ v,
    # computed as a row scaling without building the k x k diagonal matrix.
    vs = s[:, np.newaxis] * v

    u = np.reshape(u, (t.shape[0], t.shape[1], k))
    right_shape = (k,) + tuple(t.shape[2:])
    v = np.reshape(v, right_shape)
    vs = np.reshape(vs, right_shape)
    return u, s, v, vs