Module svd: SVD-based algorithms for matrices and tensors

Package teneva, module svd: SVD-based algorithms.

This module contains a basic implementation of the TT-SVD algorithm (function svd), as well as a function for constructing the skeleton decomposition of a matrix (function matrix_skeleton).




teneva_jax.svd.matrix_skeleton(A, r)

Construct the truncated skeleton decomposition A ≈ U V of the given matrix (the equality is exact when the rank of A does not exceed r).

Parameters:
  • A (jnp.ndarray) – matrix of the shape [m, n].

  • r (int) – rank of the truncated SVD.

Returns:

factor matrix U of the shape [m, r] and factor matrix V of the shape [r, n].

Return type:

(jnp.ndarray, jnp.ndarray)
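
A minimal sketch of how such a factorization can be obtained from a truncated SVD. It is not the library's code: the helper name `skeleton_sketch` is hypothetical, and absorbing the singular values into U is an assumption (the library may split them between the factors differently).

```python
import jax
import jax.numpy as jnp


def skeleton_sketch(A, r):
    """Hypothetical helper: rank-r factorization A ~ U @ V via truncated SVD."""
    # Thin SVD: A = W @ diag(s) @ Vh, singular values in descending order.
    W, s, Vh = jnp.linalg.svd(A, full_matrices=False)
    # Keep the r leading triplets; scaling the columns of W by s is an
    # assumption about where the singular values are absorbed.
    U = W[:, :r] * s[:r]
    V = Vh[:r, :]
    return U, V


# Rank-2 test matrix of shape [6, 4], built as a product of two thin factors.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(k1, (6, 2)) @ jax.random.normal(k2, (2, 4))

U, V = skeleton_sketch(A, r=2)
err = jnp.linalg.norm(A - U @ V) / jnp.linalg.norm(A)
print('Shape of U :', U.shape)
print('Shape of V :', V.shape)
print(f'Error      : {float(err):-8.2e}')
```

Since the test matrix has rank 2, the error is expected to be at the level of floating-point round-off (single precision unless 64-bit mode is enabled in JAX).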

Examples:

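# Setup assumed by the examples in this section (the import alias and
# the seed are assumptions, so the exact numbers may differ):
import jax
import jax.numpy as jnp
import teneva_jax as teneva
rng = jax.random.PRNGKey(42)

# Shape of the matrix:
m, n = 100, 30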

# Build a random rank-3 matrix as a sum of three rank-1 matrices.
# Note: the vectors v are drawn with length m (not n), so A has the
# shape [m, m]; this is the shape reported in the output below:
rng, key = jax.random.split(rng)
keys = jax.random.split(key, 6)
u = [jax.random.normal(keys[i], (m, )) for i in range(3)]
v = [jax.random.normal(keys[i], (m, )) for i in range(3, 6)]
A = jnp.outer(u[0], v[0]) + jnp.outer(u[1], v[1]) + jnp.outer(u[2], v[2])

# Compute the skeleton decomposition:
U, V = teneva.matrix_skeleton(A, r=3)

# Approximation error:
e = jnp.linalg.norm(A - U @ V) / jnp.linalg.norm(A)

print('Shape of U :', U.shape)
print('Shape of V :', V.shape)
print(f'Error      : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Shape of U : (100, 3)
# Shape of V : (3, 100)
# Error      : 2.32e-15
#

# Compute the skeleton decomposition with a rank smaller than the rank of A:
U, V = teneva.matrix_skeleton(A, r=2)

# Approximation error:
e = jnp.linalg.norm(A - U @ V) / jnp.linalg.norm(A)

print('Shape of U :', U.shape)
print('Shape of V :', V.shape)
print(f'Error      : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Shape of U : (100, 2)
# Shape of V : (2, 100)
# Error      : 4.49e-01
#
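
The size of the error for r=2 is what the Eckart-Young theorem predicts for a truncated SVD: the relative Frobenius error equals the norm of the discarded singular values divided by the norm of all singular values. The sketch below checks this with plain `jnp.linalg.svd`, assuming that matrix_skeleton truncates an SVD as its parameter description suggests. The matrix and the seed are illustrative, so the numbers differ from the output above.

```python
import jax
import jax.numpy as jnp

# Random rank-3 matrix of shape [20, 10] (illustrative data).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(k1, (20, 3)) @ jax.random.normal(k2, (3, 10))

# Thin SVD with singular values in descending order.
W, s, Vh = jnp.linalg.svd(A, full_matrices=False)

r = 2
# Best rank-r approximation in the Frobenius norm.
A_r = (W[:, :r] * s[:r]) @ Vh[:r, :]

# Measured relative error vs. the value predicted from the singular values.
e_measured = jnp.linalg.norm(A - A_r) / jnp.linalg.norm(A)
e_predicted = jnp.sqrt(jnp.sum(s[r:] ** 2) / jnp.sum(s ** 2))

print(f'Measured  : {float(e_measured):-8.2e}')
print(f'Predicted : {float(e_predicted):-8.2e}')
```

The two printed values should agree up to round-off, which shows that the error is determined entirely by the singular values dropped in the truncation.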


teneva_jax.svd.svd(Y_full, r)

Construct a TT-tensor from the given full tensor using the TT-SVD algorithm.

Parameters:
  • Y_full (jnp.ndarray) – tensor (multidimensional array) in the full format.

  • r (int) – rank of the constructed TT-tensor.

Returns:

TT-tensor, which represents the rank-r TT-approximation.

Return type:

list

Note

This function does not take advantage of JAX's ability to speed up code and can be slow. In practice it is only meaningful for tensors with a small number of dimensions, since the tensor must be available in the full format.
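
For orientation, the TT-SVD idea can be sketched as a sequence of reshapes and truncated SVDs. This is a textbook-style sketch, not the library's implementation: the helper names `tt_svd_sketch` and `tt_full_sketch` are hypothetical, and the result is a plain list of d cores of the shape [r_left, n, r_right], which may differ from the storage layout used by teneva_jax.

```python
import jax.numpy as jnp


def tt_svd_sketch(Y_full, r):
    """Hypothetical helper: TT-SVD returning a plain list of d cores."""
    n = Y_full.shape
    d = len(n)
    cores = []
    q = 1       # TT-rank to the left of the current core
    Z = Y_full  # remainder that still has to be decomposed
    for k in range(d - 1):
        # Unfold: (left rank, current mode) index rows, the rest columns.
        Z = Z.reshape(q * n[k], -1)
        W, s, Vh = jnp.linalg.svd(Z, full_matrices=False)
        rk = min(r, s.shape[0])  # the rank cannot exceed the matrix size
        cores.append(W[:, :rk].reshape(q, n[k], rk))
        # Carry the singular values over to the remainder.
        Z = s[:rk, None] * Vh[:rk, :]
        q = rk
    cores.append(Z.reshape(q, n[-1], 1))
    return cores


def tt_full_sketch(cores):
    """Hypothetical helper: contract the cores back into a full tensor."""
    Y = cores[0]
    for G in cores[1:]:
        # Contract the last axis of Y with the first axis of G.
        Y = jnp.tensordot(Y, G, axes=1)
    # Drop the boundary ranks, which are equal to 1.
    return Y[0, ..., 0]


d = 5
Z_full = jnp.cos(jnp.arange(2**d)).reshape([2] * d)
cores = tt_svd_sketch(Z_full, r=2)
err = jnp.linalg.norm(tt_full_sketch(cores) - Z_full) / jnp.linalg.norm(Z_full)
print([G.shape for G in cores])
print(f'Error : {float(err):-8.2e}')
```

Each step needs an SVD of an unfolding of the full tensor, which is why this approach is limited to tensors that fit in memory in the full format.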

Examples:

d = 5                # Number of dimensions
t = jnp.arange(2**d) # The tensor will have 2^d elements

# Construct the d-dimensional full array:
Z_full = jnp.cos(t).reshape([2] * d, order='F')

# Construct the TT-tensor by TT-SVD:
Y = teneva.svd(Z_full, r=2)

# Convert it back to the full format to check the result:
Y_full = teneva.full(Y)

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)

# Size of the original tensor:
print(f'Size (np) : {Z_full.size:-8d}')

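# Size of the TT-tensor (Y is assumed to hold three arrays: the left
# core, the stacked middle cores and the right core, which is
# consistent with the sizes printed below):
print(f'Size (tt) : {Y[0].size + Y[1].size + Y[2].size:-8d}')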

# Rel. error for the TT-tensor vs full tensor:
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Size (np) :       32
# Size (tt) :       32
# Error     : 7.78e-16
#
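
The rank-2 approximation is exact up to round-off because of the angle-addition formula cos(a + b) = cos(a) cos(b) - sin(a) sin(b): if the multi-index is split into a leading part a and a trailing part b, every unfolding of the tensor is a sum of two rank-1 matrices. A small sketch that checks the numerical rank of each unfolding with plain `jnp.linalg.svd` (the relative tolerance 1e-4 is an assumption chosen for single precision):

```python
import jax.numpy as jnp

d = 5
Z_full = jnp.cos(jnp.arange(2**d)).reshape([2] * d, order='F')

ranks = []
for k in range(1, d):
    # k-th unfolding: the first k modes index rows, the rest index columns.
    M = Z_full.reshape(2**k, 2**(d - k))
    s = jnp.linalg.svd(M, compute_uv=False)
    # Count singular values above a relative tolerance.
    ranks.append(int(jnp.sum(s > 1e-4 * s[0])))

print('Ranks of unfoldings :', ranks)
```

All four unfoldings are expected to have numerical rank 2, so a TT-tensor of rank 2 can represent this tensor without loss.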

We can also try a lower rank (in this case it leads to a large error):

# Construct TT-tensor by TT-SVD:
Y = teneva.svd(Z_full, r=1)

# Convert it back to the full format to check the result:
Y_full = teneva.full(Y)

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)

print(f'Size (np) : {Z_full.size:-8d}')
print(f'Size (tt) : {Y[0].size + Y[1].size + Y[2].size:-8d}')
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Size (np) :       32
# Size (tt) :       10
# Error     : 7.13e-01
#
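
The sizes reported in both outputs can be reproduced by counting core entries. The sketch assumes that all modes have the same size n and that all TT-ranks are equal to r; the helper name `tt_size` is hypothetical.

```python
def tt_size(d, n, r):
    """Hypothetical helper: number of entries in a TT-tensor with d modes
    of size n and all TT-ranks equal to r."""
    # Two boundary cores of the shapes [1, n, r] and [r, n, 1],
    # plus d - 2 inner cores of the shape [r, n, r].
    return 2 * n * r + (d - 2) * r * n * r


print(tt_size(d=5, n=2, r=2))  # 32
print(tt_size(d=5, n=2, r=1))  # 10
```

For d = 5 and n = 2 the full tensor has 2^5 = 32 elements, so the rank-2 TT-tensor gives no compression here, while the rank-1 TT-tensor is smaller but inaccurate.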

Note that in the JAX version, the rank cannot be greater than the mode size:

try:
    Y = teneva.svd(Z_full, r=3)
except ValueError as e:
    print('Error :', e)

# >>> ----------------------------------------
# >>> Output:

# Error : Rank can not be greater than mode size
#