Skip to content

Commit acc439e

Browse files
authored
Merge pull request tensorly#591 from dontempty/maxvol-adjustment
Update _tt_cross.py
2 parents 647c0b8 + c171151 commit acc439e

1 file changed

Lines changed: 4 additions & 7 deletions

File tree

tensorly/contrib/decomposition/_tt_cross.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,18 +452,15 @@ def maxvol(A):
452452

453453
# Find the row of max norm
454454
max_row_idx = tl.argmax(rows_norms, axis=0)
455-
max_row = A[rest_of_rows[max_row_idx], :]
455+
max_row = A_new[max_row_idx, :]
456456

457457
# Compute the projection of max_row to other rows
458-
# projection a to b is computed as: <a,b> / sqrt(|a|*|b|)
458+
# projection = <b, a>/|a|^2
459459
projection = tl.dot(A_new, tl.transpose(max_row))
460-
normalization = tl.sqrt(rows_norms[max_row_idx] * rows_norms)
461-
# make sure normalization vector is of the same shape of projection
462-
normalization = tl.reshape(normalization, tl.shape(projection))
463-
projection = projection / normalization
460+
projection = projection / (tl.sum(max_row**2))
464461

465462
# Subtract the projection from A_new: b <- b - a * projection
466-
A_new = A_new - A_new * tl.reshape(projection, (tl.shape(A_new)[0], 1))
463+
A_new = A_new - tl.tenalg.outer((projection, max_row))
467464

468465
# Delete the selected row
469466
mask.pop(tl.to_numpy(max_row_idx))

0 commit comments

Comments
 (0)