Module act_two: operations with a pair of TT-tensors
Package teneva_jax, module act_two: operations with a pair of TT-tensors.
This module contains the basic operations with a pair of TT-tensors (Y1, Y2), including “add”, “mul”, “sub”, etc.
- teneva_jax.act_two.accuracy(Y1, Y2)
Compute the relative error || Y1 - Y2 || / || Y2 || (Frobenius norm) for two tensors in the TT-format.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
the relative difference between two tensors.
- Return type:
jnp.ndarray of size 1
Examples:
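```python
import jax
import jax.numpy as jnp
import teneva_jax as teneva

jax.config.update('jax_enable_x64', True)  # Assumed: the errors shown below suggest float64
rng = jax.random.PRNGKey(42)               # Assumed seed: printed values depend on it

d = 20  # Dimension of the tensor
n = 10  # Mode size of the tensor
r = 2   # TT-rank of the tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r, key)
```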
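Let us construct the TT-tensor Y2 = Y1 + eps * Z2, where Z2 is another random TT-tensor and eps = 1.E-4 (scaling the first TT-core of Z2 by eps scales the whole tensor):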
```python
rng, key = jax.random.split(rng)
Z2 = teneva.rand(d, n, r, key)
Z2[0] = Z2[0] * 1.E-4
Y2 = teneva.add(Y1, Z2)

eps = teneva.accuracy(Y1, Y2)
print(f'Accuracy : {eps.item():-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Accuracy : 1.08e-04
```
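Note that this function works correctly even for very large dimensions, because the scalar products it relies on are computed with balancing (stabilization). In the example below Y2 = Y1 + Y1 = 2 * Y1, so the exact value is || Y1 || / || 2 * Y1 || = 0.5 for every d: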
```python
for d in [10, 50, 100, 250, 1000, 10000]:
    rng, key = jax.random.split(rng)
    Y1 = teneva.rand(d, n, r, key)
    Y2 = teneva.add(Y1, Y1)
    eps = teneva.accuracy(Y1, Y2).item()
    print(f'd = {d:-5d} | eps = {eps:-8.1e} | expected value 0.5')

# >>> ----------------------------------------
# >>> Output:

# d =    10 | eps =  5.0e-01 | expected value 0.5
# d =    50 | eps =  5.0e-01 | expected value 0.5
# d =   100 | eps =  5.0e-01 | expected value 0.5
# d =   250 | eps =  5.0e-01 | expected value 0.5
# d =  1000 | eps =  5.0e-01 | expected value 0.5
# d = 10000 | eps =  5.0e-01 | expected value 0.5
```
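The relative error can be assembled from the other operations in this module: form D = Y1 - Y2 in the TT-format and take the ratio of squared norms obtained from scalar products, keeping the stabilization exponents separate until the end. The NumPy sketch below illustrates this idea under stated assumptions: a TT-tensor is taken to be a plain list of d cores of shape (r_{k-1}, n, r_k), which may differ from the internal layout of teneva_jax, and the helper names (tt_rand, tt_add, tt_sub, tt_dot_stab, tt_accuracy) and the power-of-two rescaling rule are hypothetical, not the library's implementation.

```python
import numpy as np

# Sketch only: the tt_* helpers and the core layout are hypothetical, not the
# teneva_jax API. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_add(Y1, Y2):
    """Exact sum Y1 + Y2 via block cores; the TT-ranks add up."""
    d = len(Y1)
    if d == 1:
        return [Y1[0] + Y2[0]]
    Y = []
    for k, (G1, G2) in enumerate(zip(Y1, Y2)):
        if k == 0:
            Y.append(np.concatenate([G1, G2], axis=2))  # row block [G1, G2]
        elif k == d - 1:
            Y.append(np.concatenate([G1, G2], axis=0))  # column block [G1; G2]
        else:
            a, n, b = G1.shape
            c, _, e = G2.shape
            G = np.zeros((a + c, n, b + e))  # block-diagonal core
            G[:a, :, :b] = G1
            G[a:, :, b:] = G2
            Y.append(G)
    return Y


def tt_sub(Y1, Y2):
    """Y1 - Y2: negate the first core of Y2, then add."""
    return tt_add(Y1, [-Y2[0]] + Y2[1:])


def tt_dot_stab(Y1, Y2):
    """Scalar product as a pair (v, p) whose value is v * 2**sum(p)."""
    v = np.ones((1, 1))
    p = np.zeros(len(Y1))
    for k, (G1, G2) in enumerate(zip(Y1, Y2)):
        # v_new[b, e] = sum over a, c, i of v[a, c] * G1[a, i, b] * G2[c, i, e]
        v = np.einsum('ac,aib,cie->be', v, G1, G2)
        v_max = np.max(np.abs(v))
        if v_max > 0:
            p[k] = np.floor(np.log2(v_max))
            v = v / 2.0 ** p[k]  # dividing by a power of two is exact
    return v[0, 0], p


def tt_accuracy(Y1, Y2):
    """|| Y1 - Y2 || / || Y2 ||, combining the exponents before exponentiating."""
    D = tt_sub(Y1, Y2)
    v_d, p_d = tt_dot_stab(D, D)
    v_2, p_2 = tt_dot_stab(Y2, Y2)
    ratio = max(v_d / v_2, 0.0) * 2.0 ** (p_d.sum() - p_2.sum())
    return np.sqrt(ratio)


rng = np.random.default_rng(0)
eps_by_d = {}
for d in [10, 100, 1000]:
    Y1 = tt_rand(d, 10, 2, rng)
    Y2 = tt_add(Y1, Y1)  # Y2 = 2 * Y1, so the exact answer is 0.5
    eps_by_d[d] = tt_accuracy(Y1, Y2)
    print(f'd = {d:-5d} | eps = {eps_by_d[d]:-8.1e}')
```

For d = 1000 both squared norms lie far outside the float64 range on their own, but because only the difference of the exponent sums is ever exponentiated, the ratio stays representable.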
- teneva_jax.act_two.add(Y1, Y2)
Compute Y1 + Y2 in the TT-format.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
TT-tensor, which represents the elementwise sum of Y1 and Y2.
- Return type:
list
Examples:
```python
d = 5   # Dimension of the tensors
n = 6   # Mode size of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)
rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)

Y = teneva.add(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 + 3 = 5

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d = 5 | n = 6 | r = 5 |
```
Let us check the result:
```python
Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full + Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error : 1.98e-16
```
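The rank 2 + 3 = 5 comes from how addition can be done exactly in the TT-format: the first cores are concatenated along the right rank, the last cores along the left rank, and the middle cores are placed on the block diagonal. The NumPy sketch below illustrates this construction; it assumes a hypothetical layout (a list of d cores of shape (r_{k-1}, n, r_k)) and hypothetical helper names, and is not teneva_jax's own code.

```python
import numpy as np

# Sketch only: the tt_* helpers and the core layout are hypothetical, not the
# teneva_jax API. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Contract all cores into a dense array (only feasible for small d)."""
    Z = Y[0]
    for G in Y[1:]:
        Z = np.tensordot(Z, G, axes=1)  # last axis of Z with first axis of G
    return Z[0, ..., 0]


def tt_add(Y1, Y2):
    """Exact sum Y1 + Y2 via block cores; the TT-ranks add up."""
    d = len(Y1)
    if d == 1:
        return [Y1[0] + Y2[0]]
    Y = []
    for k, (G1, G2) in enumerate(zip(Y1, Y2)):
        if k == 0:
            Y.append(np.concatenate([G1, G2], axis=2))  # row block [G1, G2]
        elif k == d - 1:
            Y.append(np.concatenate([G1, G2], axis=0))  # column block [G1; G2]
        else:
            a, n, b = G1.shape
            c, _, e = G2.shape
            G = np.zeros((a + c, n, b + e))  # block-diagonal core
            G[:a, :, :b] = G1
            G[a:, :, b:] = G2
            Y.append(G)
    return Y


rng = np.random.default_rng(42)
Y1 = tt_rand(5, 6, 2, rng)  # TT-rank 2
Y2 = tt_rand(5, 6, 3, rng)  # TT-rank 3
Y = tt_add(Y1, Y2)

ranks = [G.shape[2] for G in Y[:-1]]
Z_full = tt_full(Y1) + tt_full(Y2)
err = np.linalg.norm(tt_full(Y) - Z_full) / np.linalg.norm(Z_full)
print(ranks)  # [5, 5, 5, 5]
print(f'Error : {err:-8.2e}')
```

No arithmetic is performed on the core entries, so the sum is exact; the price is that the ranks grow, which is why a rounding (truncation) step is usually applied afterwards.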
- teneva_jax.act_two.mul(Y1, Y2)
Compute the elementwise product Y1 * Y2 in the TT-format.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
TT-tensor, which represents the elementwise product of Y1 and Y2.
- Return type:
list
Examples:
```python
d = 5   # Dimension of the tensors
n = 6   # Mode size of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)
rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)

Y = teneva.mul(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 * 3 = 6

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d = 5 | n = 6 | r = 6 |
```
Let us check the result:
```python
Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full * Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error : 2.81e-16
```
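The rank 2 * 3 = 6 is explained by the mixed-product property of the Kronecker product: if slice i of each new core is np.kron(G1[:, i, :], G2[:, i, :]), then the chain of slice products equals the Kronecker product of two 1 x 1 chain products, i.e. the product of the two tensor entries. The NumPy sketch below shows this construction under a hypothetical list-of-cores layout with hypothetical helper names; it is an illustration, not teneva_jax's implementation.

```python
import numpy as np

# Sketch only: the tt_* helpers and the core layout are hypothetical, not the
# teneva_jax API. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Contract all cores into a dense array (only feasible for small d)."""
    Z = Y[0]
    for G in Y[1:]:
        Z = np.tensordot(Z, G, axes=1)  # last axis of Z with first axis of G
    return Z[0, ..., 0]


def tt_mul(Y1, Y2):
    """Elementwise product via Kronecker products of core slices."""
    Y = []
    for G1, G2 in zip(Y1, Y2):
        a, n, b = G1.shape
        c, _, e = G2.shape
        # Slice i of the new core equals np.kron(G1[:, i, :], G2[:, i, :]);
        # merging axes (a, c) and (b, e) in row-major order gives that layout.
        G = np.einsum('aib,cie->acibe', G1, G2).reshape(a * c, n, b * e)
        Y.append(G)
    return Y


rng = np.random.default_rng(42)
Y1 = tt_rand(5, 6, 2, rng)  # TT-rank 2
Y2 = tt_rand(5, 6, 3, rng)  # TT-rank 3
Y = tt_mul(Y1, Y2)

ranks = [G.shape[2] for G in Y[:-1]]
Z_full = tt_full(Y1) * tt_full(Y2)
err = np.linalg.norm(tt_full(Y) - Z_full) / np.linalg.norm(Z_full)
print(ranks)  # [6, 6, 6, 6]
print(f'Error : {err:-8.2e}')
```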
- teneva_jax.act_two.mul_scalar(Y1, Y2)
Compute the scalar product of Y1 and Y2 in the TT-format.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
the scalar product.
- Return type:
jnp.ndarray of size 1
Examples:
```python
d = 5   # Dimension of the tensors
n = 6   # Mode size of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)
rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)

v = teneva.mul_scalar(Y1, Y2)
print(v)  # Print the resulting value

# >>> ----------------------------------------
# >>> Output:

# [7.55067038]
```
Let us check the result:
```python
Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)

v_full = jnp.sum(Y1_full * Y2_full)
print(v_full)  # Print the resulting value from full tensor

# Compute error for TT-tensor vs full tensor:
e = jnp.abs((v - v_full) / v_full).item()
print(f'Error : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# 7.550670383793204
# Error : 1.06e-15
```
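The scalar product does not require the full tensors: it can be accumulated core by core in a small r1 x r2 matrix, at a cost linear in d. The NumPy sketch below shows this left-to-right contraction and compares it with the dense computation; the list-of-cores layout and the helper names are hypothetical and only illustrate the idea.

```python
import numpy as np

# Sketch only: the tt_* helpers and the core layout are hypothetical, not the
# teneva_jax API. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Contract all cores into a dense array (only feasible for small d)."""
    Z = Y[0]
    for G in Y[1:]:
        Z = np.tensordot(Z, G, axes=1)  # last axis of Z with first axis of G
    return Z[0, ..., 0]


def tt_dot(Y1, Y2):
    """Scalar product <Y1, Y2> by left-to-right contraction."""
    v = np.ones((1, 1))  # running r1 x r2 matrix of partial contractions
    for G1, G2 in zip(Y1, Y2):
        # v_new[b, e] = sum over a, c, i of v[a, c] * G1[a, i, b] * G2[c, i, e]
        v = np.einsum('ac,aib,cie->be', v, G1, G2)
    return v[0, 0]


rng = np.random.default_rng(42)
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)

v = tt_dot(Y1, Y2)
Y1_full, Y2_full = tt_full(Y1), tt_full(Y2)
v_full = np.sum(Y1_full * Y2_full)

# Error relative to ||Y1|| * ||Y2||, an upper bound on |v| (Cauchy-Schwarz)
scale = np.linalg.norm(Y1_full) * np.linalg.norm(Y2_full)
err = abs(v - v_full) / scale
print(v, v_full)
print(f'Error : {err:-8.2e}')
```

Each step is one small contraction, so the dense d-dimensional array is never formed; for large d, however, the running matrix can overflow, which motivates the stabilized variant that follows.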
- teneva_jax.act_two.mul_scalar_stab(Y1, Y2)
Compute the scalar product of Y1 and Y2 in the TT-format with stabilization factors.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
the scaled value v of the scalar product and the stabilization factor p for each TT-core (an array of length d). The scalar product equals v * 2**sum(p).
- Return type:
(jnp.ndarray of size 1, jnp.ndarray)
Examples:
```python
d = 5   # Dimension of the tensors
n = 6   # Mode size of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)
rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)

v, p = teneva.mul_scalar_stab(Y1, Y2)
print(v)  # Print the scaled value
print(p)  # Print the scale factors

v = v * 2**jnp.sum(p)  # Resulting value
print(v)  # Print the resulting value

# >>> ----------------------------------------
# >>> Output:

# [-1.2784655]
# [ 0. 1. 1. 1. -2.]
# [-2.55693099]
```
Let us check the result:
```python
Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)

v_full = jnp.sum(Y1_full * Y2_full)
print(v_full)  # Print the resulting value from full tensor

# Compute error for TT-tensor vs full tensor:
e = abs((v - v_full) / v_full).item()
print(f'Error : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# -2.556930991152627
# Error : 1.51e-14
```
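The stabilization factors keep the intermediate contraction bounded: after each core the accumulated matrix is divided by a power of two close to its largest entry, and that exponent is stored in p. The NumPy sketch below implements one plausible scheme (rescaling by 2**floor(log2(max|v|)), which leaves the final |v| roughly in [1, 2)); this rule and the helper names are assumptions for illustration, not necessarily what teneva_jax does internally.

```python
import numpy as np

# Sketch only: the tt_* helpers, the core layout and the rescaling rule are
# hypothetical. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_dot(Y1, Y2):
    """Plain scalar product (overflows for large d)."""
    v = np.ones((1, 1))
    for G1, G2 in zip(Y1, Y2):
        v = np.einsum('ac,aib,cie->be', v, G1, G2)
    return v[0, 0]


def tt_dot_stab(Y1, Y2):
    """Scalar product as a pair (v, p) whose value is v * 2**sum(p)."""
    v = np.ones((1, 1))
    p = np.zeros(len(Y1))
    for k, (G1, G2) in enumerate(zip(Y1, Y2)):
        v = np.einsum('ac,aib,cie->be', v, G1, G2)
        v_max = np.max(np.abs(v))
        if v_max > 0:
            p[k] = np.floor(np.log2(v_max))
            v = v / 2.0 ** p[k]  # dividing by a power of two is exact
    return v[0, 0], p


rng = np.random.default_rng(0)

# Small case: the rescaled result agrees with the plain scalar product.
Y1 = tt_rand(5, 6, 2, rng)
Y2 = tt_rand(5, 6, 3, rng)
v, p = tt_dot_stab(Y1, Y2)
v_plain = tt_dot(Y1, Y2)
print(v, p, v * 2.0 ** p.sum(), v_plain)

# Large case: ||Y||^2 for d = 1000 is far beyond the float64 range,
# but the pair (v, p) still represents it; report log2 of the value.
Y = tt_rand(1000, 10, 2, rng)
v_big, p_big = tt_dot_stab(Y, Y)
log2_norm_sq = np.log2(v_big) + p_big.sum()
print(f'log2(||Y||^2) = {log2_norm_sq:.1f}')
```

Scaling by powers of two is used because it changes only the floating-point exponent, so the rescaling itself introduces no rounding error.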
- teneva_jax.act_two.sub(Y1, Y2)
Compute Y1 - Y2 in the TT-format.
- Parameters:
Y1 (list) – TT-tensor.
Y2 (list) – TT-tensor.
- Returns:
TT-tensor, which represents the difference Y1 - Y2.
- Return type:
list
Examples:
```python
d = 5   # Dimension of the tensors
n = 6   # Mode size of the tensors
r1 = 2  # TT-rank of the first tensor
r2 = 3  # TT-rank of the second tensor

rng, key = jax.random.split(rng)
Y1 = teneva.rand(d, n, r1, key)
rng, key = jax.random.split(rng)
Y2 = teneva.rand(d, n, r2, key)

Y = teneva.sub(Y1, Y2)
teneva.show(Y)  # Note that the result has TT-rank 2 + 3 = 5

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d = 5 | n = 6 | r = 5 |
```
Let us check the result:
```python
Y1_full = teneva.full(Y1)
Y2_full = teneva.full(Y2)
Y_full = teneva.full(Y)

Z_full = Y1_full - Y2_full

# Compute error for TT-tensor vs full tensor:
e = jnp.linalg.norm(Y_full - Z_full)
e /= jnp.linalg.norm(Z_full)
print(f'Error : {e:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error : 1.77e-16
```
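Subtraction reduces to addition: negating a single core (here the first one) negates the whole tensor, so Y1 - Y2 = Y1 + (-Y2) and the TT-ranks again add up (2 + 3 = 5). The NumPy sketch below shows this reduction; the list-of-cores layout and helper names are hypothetical and serve only as an illustration.

```python
import numpy as np

# Sketch only: the tt_* helpers and the core layout are hypothetical, not the
# teneva_jax API. Core k has shape (r_{k-1}, n, r_k) with r_0 = r_d = 1.


def tt_rand(d, n, r, rng):
    """Random TT-tensor with Gaussian cores."""
    ranks = [1] + [r] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]


def tt_full(Y):
    """Contract all cores into a dense array (only feasible for small d)."""
    Z = Y[0]
    for G in Y[1:]:
        Z = np.tensordot(Z, G, axes=1)  # last axis of Z with first axis of G
    return Z[0, ..., 0]


def tt_add(Y1, Y2):
    """Exact sum Y1 + Y2 via block cores; the TT-ranks add up."""
    d = len(Y1)
    if d == 1:
        return [Y1[0] + Y2[0]]
    Y = []
    for k, (G1, G2) in enumerate(zip(Y1, Y2)):
        if k == 0:
            Y.append(np.concatenate([G1, G2], axis=2))  # row block [G1, G2]
        elif k == d - 1:
            Y.append(np.concatenate([G1, G2], axis=0))  # column block [G1; G2]
        else:
            a, n, b = G1.shape
            c, _, e = G2.shape
            G = np.zeros((a + c, n, b + e))  # block-diagonal core
            G[:a, :, :b] = G1
            G[a:, :, b:] = G2
            Y.append(G)
    return Y


def tt_sub(Y1, Y2):
    """Y1 - Y2: negate the first core of Y2, then add."""
    return tt_add(Y1, [-Y2[0]] + Y2[1:])


rng = np.random.default_rng(42)
Y1 = tt_rand(5, 6, 2, rng)  # TT-rank 2
Y2 = tt_rand(5, 6, 3, rng)  # TT-rank 3
Y = tt_sub(Y1, Y2)

ranks = [G.shape[2] for G in Y[:-1]]
Z_full = tt_full(Y1) - tt_full(Y2)
err = np.linalg.norm(tt_full(Y) - Z_full) / np.linalg.norm(Z_full)
print(ranks)  # [5, 5, 5, 5]
print(f'Error : {err:-8.2e}')
```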