Module act_two: operations with a pair of TT-tensors

Package teneva, module act_two: operations with a pair of TT-tensors.

This module contains the basic operations with a pair of TT-tensors (Y1, Y2), including “add”, “mul”, “sub”, etc.




teneva_jax.act_two.accuracy(Y1, Y2)

Compute the relative error || Y1 - Y2 || / || Y2 || of two tensors in the TT-format, where || . || is the Frobenius norm.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

the relative difference between two tensors.

Return type:

jnp.ndarray of size 1

Examples:

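import jax
import jax.numpy as jnp
import teneva_jax as teneva
jax.config.update('jax_enable_x64', True)  # the outputs below are in double precision
rng = jax.random.PRNGKey(42)               # any seed; the printed values below depend on it
d = 20  # Dimension of the tensor
n = 10  # Mode size of the tensor
r = 2   # TT-rank of the tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r, key)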

Let us construct the TT-tensor Y2 = Y1 + eps * Z2, where Z2 is a random TT-tensor and eps = 1.E-4 (the factor is applied to the first TT-core of Z2):

rng, key = jax.random.split(rng)
Z2 = teneva.rand(d, n, r, key)
Z2[0] = Z2[0] * 1.E-4

Y2 = teneva.add(Y1, Z2)
eps = teneva.accuracy(Y1, Y2)

print(f'Accuracy     : {eps.item():-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Accuracy     : 1.08e-04
#
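Scaling only Z2[0] is enough because a TT-tensor is linear in each of its cores. Multiplying any single core by a number multiplies every entry of the full tensor by that number. The following is a minimal NumPy sketch that checks this numerically. It assumes a TT-tensor is stored as a plain list of 3D cores of shape (r_{k-1}, n, r_k) with r_0 = r_d = 1. That is the classic layout, not necessarily the internal one of teneva_jax. `tt_rand` and `tt_full` are hypothetical helpers written only for this illustration.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array of shape (n, ..., n)."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


rng = np.random.default_rng(0)
Z = tt_rand(5, 4, 3, rng)
eps = 1.E-4

for k in range(len(Z)):
    Zk = list(Z)           # new list; the cores themselves are shared with Z
    Zk[k] = Zk[k] * eps    # scale only the k-th core (this creates a new array)
    err = np.linalg.norm(tt_full(Zk) - eps * tt_full(Z))
    print(f'scaled core {k} | error = {err:-8.2e}')
```

Every printed error is at round-off level, whichever core carries the factor.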

Note that this function works correctly even for very large dimensions d, thanks to the balancing (stabilization) used in the scalar product. With Y2 = Y1 + Y1 = 2 * Y1, the exact value is || Y1 - 2 * Y1 || / || 2 * Y1 || = 0.5 for any d:

for d in [10, 50, 100, 250, 1000, 10000]:
    rng, key = jax.random.split(rng)
    Y1 = teneva.rand(d, n, r, key)
    Y2 = teneva.add(Y1, Y1)
    eps = teneva.accuracy(Y1, Y2).item()

    print(f'd = {d:-5d} | eps = {eps:-8.1e} | expected value 0.5')

# >>> ----------------------------------------
# >>> Output:

# d =    10 | eps =  5.0e-01 | expected value 0.5
# d =    50 | eps =  5.0e-01 | expected value 0.5
# d =   100 | eps =  5.0e-01 | expected value 0.5
# d =   250 | eps =  5.0e-01 | expected value 0.5
# d =  1000 | eps =  5.0e-01 | expected value 0.5
# d = 10000 | eps =  5.0e-01 | expected value 0.5
#
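Without balancing, the expected value 0.5 is easy to lose. For large d, the scalar products behind || Y1 - Y2 || and || Y2 || typically grow or shrink exponentially with d, so they overflow or underflow in floating point. Below is a minimal NumPy sketch of one remedy, again assuming list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k). The running transfer matrix is renormalized after every core, and the removed scale is accumulated as a logarithm. Some caveats apply:

- The helper names are hypothetical.
- The log-domain scheme is chosen here for illustration; mul_scalar_stab below works with powers of two instead.
- The expansion || Y1 - Y2 ||^2 = <Y1, Y1> - 2 <Y1, Y2> + <Y2, Y2> loses precision when Y1 and Y2 are very close.

So this is not necessarily how accuracy is implemented.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_dot_log(A, B):
    """Return (s, L) with <A, B> = s * exp(L); the running matrix is kept at unit norm."""
    v = np.ones((1, 1))
    L = 0.0
    for GA, GB in zip(A, B):
        # v[c, e] <- sum_{a, b, i} v[a, b] * GA[a, i, c] * GB[b, i, e]
        v = np.einsum('ab,aic,bie->ce', v, GA, GB)
        nrm = np.linalg.norm(v)
        if nrm == 0.0:
            return 0.0, 0.0
        v = v / nrm            # rescale to unit norm, so the magnitude of v stays bounded
        L += np.log(nrm)       # keep the removed scale in log-domain
    return v[0, 0], L


def tt_accuracy_log(Y1, Y2):
    """|| Y1 - Y2 || / || Y2 || from three stabilized scalar products."""
    s11, L11 = tt_dot_log(Y1, Y1)
    s12, L12 = tt_dot_log(Y1, Y2)
    s22, L22 = tt_dot_log(Y2, Y2)
    # (<Y1,Y1> - 2 <Y1,Y2> + <Y2,Y2>) / <Y2,Y2>, every term divided by exp(L22) first
    q = (s11 * np.exp(L11 - L22) - 2.0 * s12 * np.exp(L12 - L22) + s22) / s22
    return np.sqrt(max(q, 0.0))    # clip a tiny negative value caused by round-off


rng = np.random.default_rng(42)
for d in [10, 100, 1000, 10000]:
    Y1 = tt_rand(d, 10, 2, rng)
    Y2 = [2.0 * Y1[0]] + Y1[1:]    # Y2 = 2 * Y1, i.e. the same tensor as Y1 + Y1
    eps = tt_accuracy_log(Y1, Y2)
    L = tt_dot_log(Y1, Y1)[1]      # natural log of <Y1, Y1>
    print(f'd = {d:-5d} | eps = {eps:-8.1e} | ln<Y1, Y1> = {L:-9.1f}')
```

The last column shows why the scale is tracked separately. Once ln<Y1, Y1> exceeds about 709, the value <Y1, Y1> itself no longer fits into float64.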


teneva_jax.act_two.add(Y1, Y2)

Compute Y1 + Y2 in the TT-format.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

TT-tensor that represents the elementwise sum of Y1 and Y2.

Return type:

list

Examples:

d = 5   # Dimension of the tensor
n = 6   # Mode size of the tensor
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)

rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)
Y = teneva.add(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 + 3 = 5

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d =     5 | n =     6 | r =     5 |
#

Let us check the result:

Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full + Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error     : 1.98e-16
#
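The TT-rank 2 + 3 = 5 noted in the example follows from the standard TT-sum construction:

- the first cores are concatenated side by side,
- the inner cores are placed on a block diagonal,
- the last cores are stacked.

Multiplying the cores then yields the chain of Y1 plus the chain of Y2. Below is a minimal NumPy sketch of this construction. It assumes list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k), and all helper names are hypothetical. It is not the teneva_jax source, which stores its cores differently.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_add(A, B):
    """Elementwise sum of two TT-tensors; the inner TT-ranks add up."""
    if len(A) == 1:
        return [A[0] + B[0]]
    C = []
    for k, (GA, GB) in enumerate(zip(A, B)):
        ra0, n, ra1 = GA.shape
        rb0, _, rb1 = GB.shape
        if k == 0:
            G = np.concatenate([GA, GB], axis=2)   # row block [GA  GB]
        elif k == len(A) - 1:
            G = np.concatenate([GA, GB], axis=0)   # column block [GA; GB]
        else:
            G = np.zeros((ra0 + rb0, n, ra1 + rb1))
            G[:ra0, :, :ra1] = GA                  # block-diagonal inner core
            G[ra0:, :, ra1:] = GB
        C.append(G)
    return C


rng = np.random.default_rng(0)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)
Y = tt_add(Y1, Y2)
print([G.shape for G in Y])  # [(1, 6, 5), (5, 6, 5), (5, 6, 5), (5, 6, 5), (5, 6, 1)]

Z = tt_full(Y1) + tt_full(Y2)
print(np.linalg.norm(tt_full(Y) - Z) / np.linalg.norm(Z))
```

The rank grows with every addition. In practice, a sum is therefore often followed by a TT-rounding step, which is not shown here.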


teneva_jax.act_two.mul(Y1, Y2)

Compute the elementwise product Y1 * Y2 in the TT-format.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

TT-tensor that represents the elementwise product of Y1 and Y2.

Return type:

list

Examples:

d = 5   # Dimension of the tensor
n = 6   # Mode size of the tensor
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)

rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)
Y = teneva.mul(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 * 3 = 6

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d =     5 | n =     6 | r =     6 |
#

Let us check the result:

Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full * Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error     : 2.81e-16
#
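The TT-rank 2 * 3 = 6 comes from a standard construction. For every mode index i, take the Kronecker product of the matching core slices. By the mixed-product property, kron(A1, B1) kron(A2, B2) = kron(A1 A2, B1 B2), so the chain of slice products collapses to the product of the two 1 x 1 results. That product is Y1[i1, ..., id] * Y2[i1, ..., id]. Below is a minimal NumPy sketch of this construction. It assumes list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k), and the helper names are hypothetical, not the teneva_jax API.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_mul(A, B):
    """Elementwise product of two TT-tensors; the TT-ranks multiply."""
    C = []
    for GA, GB in zip(A, B):
        ra0, n, ra1 = GA.shape
        rb0, _, rb1 = GB.shape
        # G[(a, b), i, (c, e)] = GA[a, i, c] * GB[b, i, e], i.e. kron of the i-th slices
        G = np.einsum('aic,bie->abice', GA, GB)
        C.append(G.reshape(ra0 * rb0, n, ra1 * rb1))
    return C


rng = np.random.default_rng(0)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)
Y = tt_mul(Y1, Y2)
print([G.shape for G in Y])  # [(1, 6, 6), (6, 6, 6), (6, 6, 6), (6, 6, 6), (6, 6, 1)]

Z = tt_full(Y1) * tt_full(Y2)
print(np.linalg.norm(tt_full(Y) - Z) / np.linalg.norm(Z))
```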


teneva_jax.act_two.mul_scalar(Y1, Y2)

Compute the scalar product of Y1 and Y2 in the TT-format, i.e., the sum of all entries of the elementwise product Y1 * Y2.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

the scalar product.

Return type:

jnp.ndarray of size 1

Examples:

d = 5   # Dimension of the tensor
n = 6   # Mode size of the tensor
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)

rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)
v = teneva.mul_scalar(Y1, Y2)

print(v) # Print the resulting value

# >>> ----------------------------------------
# >>> Output:

# [7.55067038]
#

Let us check the result:

Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)

v_full = jnp.sum(Y1_full * Y2_full)

print(v_full) # Print the resulting value from full tensor

# Compute error for TT-tensor vs full tensor:
e = jnp.abs((v - v_full)/v_full).item()
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# 7.550670383793204
# Error     : 1.06e-15
#
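The scalar product never requires the full tensors. The sum over all entries of Y1 * Y2 can be accumulated core by core through a small r1-by-r2 "transfer" matrix. This is the standard way to evaluate it in the TT-format. Below is a minimal NumPy sketch, assuming list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k). The helper names are hypothetical, and this is not the teneva_jax source.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array (for checking only)."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_dot(A, B):
    """Scalar product <A, B> = sum(A * B), computed without forming the full tensors."""
    v = np.ones((1, 1))
    for GA, GB in zip(A, B):
        # v[c, e] <- sum_{a, b, i} v[a, b] * GA[a, i, c] * GB[b, i, e]
        v = np.einsum('ab,aic,bie->ce', v, GA, GB)
    return v[0, 0]


rng = np.random.default_rng(0)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)

v = tt_dot(Y1, Y2)
v_full = np.sum(tt_full(Y1) * tt_full(Y2))
print(v, v_full, abs((v - v_full) / v_full))
```

Only an r1-by-r2 matrix is kept between steps, so the memory used does not grow with d. The magnitude of v, however, can still overflow or underflow for large d, which is what mul_scalar_stab addresses.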


teneva_jax.act_two.mul_scalar_stab(Y1, Y2)

Compute the scalar product of Y1 and Y2 in the TT-format with stabilization factors.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

the scaled value v of the scalar product and the stabilization factors p, one per TT-core (an array of length d). The scalar product itself equals v * 2^{sum(p)}.

Return type:

(jnp.ndarray of size 1, jnp.ndarray)

Examples:

d = 5   # Dimension of the tensor
n = 6   # Mode size of the tensor
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)

rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)
v, p = teneva.mul_scalar_stab(Y1, Y2)
print(v) # Print the scaled value
print(p) # Print the scale factors

v = v * 2**jnp.sum(p) # Resulting value
print(v)   # Print the resulting value

# >>> ----------------------------------------
# >>> Output:

# [-1.2784655]
# [ 0.  1.  1.  1. -2.]
# [-2.55693099]
#

Let us check the result:

Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)

v_full = jnp.sum(Y1_full * Y2_full)

print(v_full) # Print the resulting value from full tensor

# Compute error for TT-tensor vs full tensor:
e = abs((v - v_full)/v_full).item()
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# -2.556930991152627
# Error     : 1.51e-14
#
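One way to produce a pair (v, p) of this kind is as follows. After each core, factor out a power of two from the transfer matrix and record its exponent. Multiplying by a power of two only changes the floating-point exponent, so as long as no entry underflows, no rounding error is introduced.

Below is a minimal NumPy sketch, assuming list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k); the helper names are hypothetical. In this sketch, the final |v| lies in [0.5, 1). The example above prints v = -1.2784655, so teneva_jax evidently chooses its factors differently. Treat this as an illustration of the idea, not as the library's algorithm.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array (for checking only)."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_dot_stab(A, B):
    """Return (v, p) such that <A, B> = v * 2**sum(p), with one exponent per core."""
    v = np.ones((1, 1))
    p = np.zeros(len(A))
    for k, (GA, GB) in enumerate(zip(A, B)):
        v = np.einsum('ab,aic,bie->ce', v, GA, GB)
        # frexp writes max|v| as m * 2**e with 0.5 <= m < 1 (or m = e = 0 for zero)
        _, e = np.frexp(np.max(np.abs(v)))
        v = np.ldexp(v, -e)    # divide by 2**e: only the exponents of the entries change
        p[k] = e
    return v[0, 0], p


rng = np.random.default_rng(0)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)
v, p = tt_dot_stab(Y1, Y2)
print(v, p)
print(np.ldexp(v, int(p.sum())), np.sum(tt_full(Y1) * tt_full(Y2)))

# For large d, 2**sum(p) may not fit into float64, but log2|<A, B>| still does:
A = tt_rand(2000, 6, 3, rng)
B = tt_rand(2000, 6, 3, rng)
v, p = tt_dot_stab(A, B)
print(np.log2(abs(v)) + p.sum())
```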


teneva_jax.act_two.sub(Y1, Y2)

Compute Y1 - Y2 in the TT-format.

Parameters:
  • Y1 (list) – TT-tensor.

  • Y2 (list) – TT-tensor.

Returns:

TT-tensor that represents the elementwise difference Y1 - Y2.

Return type:

list

Examples:

d = 5   # Dimensions of the tensors
n = 6   # Mode sizes of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor
rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)

rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)
Y = teneva.sub(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 + 3 = 5

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d =     5 | n =     6 | r =     5 |
#

Let us check the result:

Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full - Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error     : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error     : 1.77e-16
#
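Subtraction can use the same block structure as addition. Multiplying one TT-core by -1 negates the whole tensor, so it is enough to flip the sign of the first core of Y2. The TT-ranks then add up again (2 + 3 = 5). Below is a minimal NumPy sketch, assuming list-of-cores TT-tensors with cores of shape (r_{k-1}, n, r_k). The helper names are hypothetical, and this is not necessarily how teneva_jax implements sub.

```python
import numpy as np


def tt_rand(d, n, r, rng):
    """Hypothetical helper: random TT-tensor as a list of cores (r_prev, n, r_next)."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.uniform(-1.0, 1.0, (ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Hypothetical helper: contract all cores into a dense array."""
    res = Y[0]
    for G in Y[1:]:
        res = np.tensordot(res, G, axes=([-1], [0]))
    return res[0, ..., 0]


def tt_sub(A, B):
    """Elementwise difference A - B; the sign of B enters through its first core only."""
    if len(A) == 1:
        return [A[0] - B[0]]
    C = []
    for k, (GA, GB) in enumerate(zip(A, B)):
        ra0, n, ra1 = GA.shape
        rb0, _, rb1 = GB.shape
        if k == 0:
            G = np.concatenate([GA, -GB], axis=2)  # row block [GA  -GB]
        elif k == len(A) - 1:
            G = np.concatenate([GA, GB], axis=0)   # column block [GA; GB]
        else:
            G = np.zeros((ra0 + rb0, n, ra1 + rb1))
            G[:ra0, :, :ra1] = GA                  # block-diagonal inner core
            G[ra0:, :, ra1:] = GB
        C.append(G)
    return C


rng = np.random.default_rng(0)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)
Y = tt_sub(Y1, Y2)
print([G.shape for G in Y])  # [(1, 6, 5), (5, 6, 5), (5, 6, 5), (5, 6, 5), (5, 6, 1)]

Z = tt_full(Y1) - tt_full(Y2)
print(np.linalg.norm(tt_full(Y) - Z) / np.linalg.norm(Z))
```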