Module tensors: collection of explicit useful TT-tensors

Package teneva_jax, module tensors: various useful TT-tensors.

This module contains functions for the explicit construction of useful TT-tensors (currently only random tensors).




teneva_jax.tensors.rand(d, n, r, key, a=-1.0, b=1.0)

Construct a random TT-tensor from the uniform distribution.

Parameters:
  • d (int) – number of tensor dimensions.

  • n (int) – mode size of the tensor.

  • r (int) – TT-rank of the tensor.

  • key (jax.random.PRNGKey) – JAX random key.

  • a (float) – minimum value for random items of the TT-cores.

  • b (float) – maximum value for random items of the TT-cores.

Returns:

TT-tensor.

Return type:

list

Examples:

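import jax
import teneva_jax as teneva      # Alias assumed by the calls below

rng = jax.random.PRNGKey(42)     # Root key (the seed is arbitrary)

d = 6                            # Number of tensor dimensions
n = 5                            # Mode size of the tensor
r = 4                            # TT-rank of the TT-tensor
rng, key = jax.random.split(rng) # Split off a key for this call
Y = teneva.rand(d, n, r, key)    # Build the random TT-tensor
teneva.show(Y)                   # Print the resulting TT-tensor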

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d =     6 | n =     5 | r =     4 |
#
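
The key handling in the example above follows the usual JAX pattern: a root key is split before each sampling call, and the split-off key is used only once. The following minimal sketch illustrates this with jax.random alone (no teneva_jax calls); the seed and the array shape are arbitrary choices made for illustration.

```python
import jax

rng = jax.random.PRNGKey(0)          # Root key (arbitrary seed)
rng, key1 = jax.random.split(rng)    # key1 is meant for one sampling call
rng, key2 = jax.random.split(rng)    # A fresh key for the next call

x1 = jax.random.uniform(key1, (3,))
x1_again = jax.random.uniform(key1, (3,))  # Same key: identical numbers
x2 = jax.random.uniform(key2, (3,))        # Different key: different numbers

print(x1, x1_again, x2)
```

Reusing a key reproduces the same values, which is why each call to the random constructors should receive a newly split key.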

We may use custom limits:

d = 6                            # Number of tensor dimensions
n = 5                            # Mode size of the tensor
r = 4                            # TT-rank of the TT-tensor
a = 0.99                         # Minimum value
b = 1.                           # Maximum value
rng, key = jax.random.split(rng)
Y = teneva.rand(d, n, r, key, a, b)
print(Y[0])                      # Print the first TT-core

# >>> ----------------------------------------
# >>> Output:

# [[[0.99905933 0.99505376 0.99201173 0.99603783]
#   [0.9982403  0.99355506 0.9977989  0.99978416]
#   [0.99381576 0.99769924 0.99593848 0.99955382]
#   [0.99640582 0.99803304 0.99341177 0.99905888]
#   [0.99696002 0.99767435 0.99508183 0.99683427]]]
#
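
For intuition, the idea of filling TT-cores with uniform random values can be sketched in plain JAX. This is not the teneva_jax implementation: the helper name rand_tt_sketch and the storage of the TT-tensor as a plain list of d cores of shape (r_left, n, r_right) are assumptions made for illustration, and the library may lay out its cores differently.

```python
import jax


def rand_tt_sketch(d, n, r, key, a=-1.0, b=1.0):
    """Hypothetical helper: d uniform random TT-cores as a plain list."""
    # Boundary TT-ranks are 1, all internal TT-ranks are r.
    ranks = [1] + [r] * (d - 1) + [1]
    # One independent subkey per core.
    keys = jax.random.split(key, d)
    return [
        jax.random.uniform(
            keys[i], (ranks[i], n, ranks[i + 1]), minval=a, maxval=b)
        for i in range(d)
    ]


key = jax.random.PRNGKey(0)  # Arbitrary seed
cores = rand_tt_sketch(6, 5, 4, key, a=0.99, b=1.0)
shapes = [G.shape for G in cores]
print(shapes)
```

In this sketch the first core has shape (1, n, r), which matches the shape of the printed Y[0] above.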


teneva_jax.tensors.rand_norm(d, n, r, key, m=0.0, s=1.0)

Construct a random TT-tensor from the normal distribution.

Parameters:
  • d (int) – number of tensor dimensions.

  • n (int) – mode size of the tensor.

  • r (int) – TT-rank of the tensor.

  • key (jax.random.PRNGKey) – JAX random key.

  • m (float) – mean (“centre”) of the distribution.

  • s (float) – standard deviation of the distribution (>0).

Returns:

TT-tensor.

Return type:

list

Examples:

d = 6                               # Number of tensor dimensions
n = 5                               # Mode size of the tensor
r = 4                               # TT-rank of the TT-tensor
rng, key = jax.random.split(rng)
Y = teneva.rand_norm(d, n, r, key)  # Build the random TT-tensor
teneva.show(Y)                      # Print the resulting TT-tensor

# >>> ----------------------------------------
# >>> Output:

# TT-tensor-jax | d =     6 | n =     5 | r =     4 |
#

We may use a custom mean and standard deviation:

d = 6                               # Number of tensor dimensions
n = 5                               # Mode size of the tensor
r = 4                               # TT-rank of the TT-tensor
m = 42.                             # Mean ("centre")
s = 0.0001                          # Standard deviation
rng, key = jax.random.split(rng)
Y = teneva.rand_norm(d, n, r, key, m, s)
print(Y[0])                         # Print the first TT-core

# >>> ----------------------------------------
# >>> Output:

# [[[42.00022745 42.00018383 41.99995424 41.99999947]
#   [42.00010626 42.00004057 42.00015906 41.99983497]
#   [42.00001789 41.99989299 42.00008431 41.99996506]
#   [42.00011325 41.99989364 41.9999467  42.00013334]
#   [41.99989569 42.0000333  42.00003193 42.00000196]]]
#
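
The effect of m and s can be illustrated by shifting and scaling standard normal samples in plain JAX. The sketch below is an assumption-laden illustration, not the teneva_jax implementation: the helper name rand_norm_tt_sketch and the list-of-cores layout with core shapes (r_left, n, r_right) are hypothetical.

```python
import jax
import jax.numpy as jnp


def rand_norm_tt_sketch(d, n, r, key, m=0.0, s=1.0):
    """Hypothetical helper: d normally distributed TT-cores as a list."""
    # Boundary TT-ranks are 1, all internal TT-ranks are r.
    ranks = [1] + [r] * (d - 1) + [1]
    # One independent subkey per core.
    keys = jax.random.split(key, d)
    # Standard normal samples, scaled by s and shifted by m.
    return [
        m + s * jax.random.normal(keys[i], (ranks[i], n, ranks[i + 1]))
        for i in range(d)
    ]


key = jax.random.PRNGKey(0)  # Arbitrary seed
cores_norm = rand_norm_tt_sketch(6, 5, 4, key, m=42.0, s=1e-4)

# Largest absolute deviation from the mean over all cores.
max_dev = max(float(jnp.max(jnp.abs(G - 42.0))) for G in cores_norm)
print(max_dev)
```

With a small standard deviation, all core entries stay close to the mean, as in the printed first core above.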