Module als: construct TT-tensor by TT-ALS

Package teneva, module als: construct a TT-tensor using TT-ALS.

This module contains the function “als”, which computes a TT-approximation of a tensor with the TT-ALS algorithm from given random samples (i.e., a set of random tensor multi-indices and the corresponding tensor values).
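For reference, the TT-format stores a d-dimensional tensor as a sequence of three-dimensional cores G_1, ..., G_d. Assuming the common convention that core G_k has shape r_{k-1} × n_k × r_k with r_0 = r_d = 1, every tensor element is a product of matrices taken from the cores:

```latex
\mathcal{Y}[i_1, i_2, \ldots, i_d] = G_1[:, i_1, :]\, G_2[:, i_2, :] \cdots G_d[:, i_d, :],
\qquad G_k \in \mathbb{R}^{r_{k-1} \times n_k \times r_k}, \quad r_0 = r_d = 1.
```

For a fixed multi-index this expression is linear in each single core, which is what makes alternating least squares over the cores possible.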




teneva_jax.als.als(I_trn, y_trn, Y0, nswp=50)

Build a TT-tensor by the TT-ALS method from given random tensor samples.

Note that this function uses inner jax.jit calls, so it is not recommended to wrap the function itself in “jax.jit”.

Parameters:
  • I_trn (jnp.ndarray) – multi-indices for the tensor in the form of an array of shape [samples, d], where d is the number of tensor dimensions and samples is the size of the training dataset.

  • y_trn (jnp.ndarray) – tensor values at the multi-indices I_trn, in the form of an array of shape [samples].

  • Y0 (list) – TT-tensor used as the initial approximation for the algorithm.

  • nswp (int) – number of ALS iterations (sweeps).

Returns:

TT-tensor, which represents the TT-approximation for the tensor.

Return type:

list
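As an illustration of the list-of-cores layout, here is a plain NumPy sketch. It assumes each core is a 3D array of shape [r_{k-1}, n_k, r_k], as in the base numpy teneva package; the JAX version may store cores differently (the example below converts with teneva.convert), and the helper names tt_rand, tt_get_many and tt_full are hypothetical, not part of teneva or teneva_jax.

```python
import numpy as np

def tt_rand(dims, rank, rng):
    """Hypothetical helper: random TT-tensor as a list of cores [r_prev, n_k, r_next]."""
    d = len(dims)
    ranks = [1] + [rank] * (d - 1) + [1]
    return [rng.standard_normal((ranks[k], dims[k], ranks[k + 1])) for k in range(d)]

def tt_get_many(cores, I):
    """Hypothetical helper: evaluate the TT-tensor at each row of I ([samples, d])."""
    I = np.asarray(I)
    # One [1]-sized row vector per sample (r_0 = 1).
    Q = np.ones((I.shape[0], 1))
    for k, G in enumerate(cores):
        # G[:, I[:, k], :] has shape [r_prev, samples, r_next]; contract over r_prev.
        Q = np.einsum('sp,psq->sq', Q, G[:, I[:, k], :])
    # After the last core r_d = 1, so each sample holds a single value.
    return Q[:, 0]

def tt_full(cores):
    """Hypothetical helper: assemble the full tensor (only feasible for small d and n)."""
    Y = cores[0]
    for G in cores[1:]:
        Y = np.tensordot(Y, G, axes=([-1], [0]))
    return Y[0, ..., 0]

rng = np.random.default_rng(0)
cores = tt_rand([4, 5, 6], 3, rng)
I = np.array([[0, 1, 2], [3, 4, 5], [1, 0, 0]])
print(tt_get_many(cores, I))
```

Comparing tt_get_many against entries of tt_full is a cheap way to check the contraction order on small tensors.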

Examples:

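import jax
import jax.numpy as jnp
import teneva as teneva_base       # Base (numpy) version of teneva
import teneva_jax as teneva        # JAX version of teneva
from time import perf_counter as tpc
rng = jax.random.PRNGKey(42)       # Random key (the seed value is arbitrary)
d = 20             # Dimension of the function
n = 10             # Mode size of the tensor
r = 5              # TT-rank of the initial approximation
nswp = 50          # Number of ALS sweeps
m = int(1.E+5)     # Number of calls to target function
m_tst = int(1.E+4) # Number of test points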

We set the target function (the function takes as input a multi-index i of shape [d], which is transformed into a point x of a uniform spatial grid):

a = -2.048 # Lower bound for the spatial grid
b = +2.048 # Upper bound for the spatial grid

def func_base(i):
    """Rosenbrock function."""
    x = i / n * (b - a) + a
    y1 = 100. * (x[1:] - x[:-1]**2)**2
    y2 = (x[:-1] - 1.)**2
    return jnp.sum(y1 + y2)

func = jax.vmap(func_base)
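To see how multi-indices map to grid points, here is a small NumPy check of the same mapping and of the formula that func_base returns (the Rosenbrock function). The names to_grid and func_np are hypothetical NumPy twins used only for illustration. With n = 10 the grid step is (b - a)/n = 0.4096, index 5 maps to x = 0, and the largest index 9 maps to about 1.6384, so the upper bound b itself is never reached.

```python
import numpy as np

n, a, b = 10, -2.048, 2.048

def to_grid(i):
    # Same index-to-point mapping as in func_base: x = i / n * (b - a) + a.
    return np.asarray(i) / n * (b - a) + a

def func_np(i):
    """Hypothetical NumPy twin of func_base (Rosenbrock function)."""
    x = to_grid(i)
    y1 = 100. * (x[1:] - x[:-1]**2)**2
    y2 = (x[:-1] - 1.)**2
    return np.sum(y1 + y2)

# At i = (5, ..., 5) every x_k = 0, so each of the d - 1 = 19 terms equals 1.
print(func_np(np.full(20, 5)))  # 19.0
```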

We prepare training data using Latin hypercube sampling (LHS):

rng, key = jax.random.split(rng)
I_trn = teneva.sample_lhs(d, n, m, key)
y_trn = func(I_trn)
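The idea behind Latin hypercube sampling on a discrete grid can be sketched in NumPy: in every dimension each index value appears (almost) equally often, and the columns are shuffled independently. The helper lhs_indices is hypothetical and only illustrates the idea; teneva.sample_lhs may differ in its details.

```python
import numpy as np

def lhs_indices(d, n, m, rng):
    """Hypothetical sketch of discrete LHS: m multi-indices, shape [m, d]."""
    # Balanced column: values 0..n-1 repeated cyclically, truncated to m entries.
    base = np.arange(m) % n
    # Independent random permutation of the balanced column for each dimension.
    return np.stack([rng.permutation(base) for _ in range(d)], axis=1)

rng = np.random.default_rng(0)
I = lhs_indices(d=3, n=4, m=10, rng=rng)
print(I.shape)  # (10, 3)
```

Compared with uniform random sampling, this keeps the per-dimension index counts within one of each other, so no grid value is left unsampled when m ≥ n.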

We prepare test data from random tensor multi-indices:

rng, key = jax.random.split(rng)
I_tst = teneva.sample_rand(d, n, m_tst, key)
y_tst = func(I_tst)

We build the initial approximation by the TT-ANOVA method from the base (numpy) package and convert it to the teneva_jax format:

# TODO: replace with jax-version!
Y_anova_base = teneva_base.anova(I_trn, y_trn, r)
Y_anova = teneva.convert(Y_anova_base)

Now we build the TT-tensor that approximates the target function by the TT-ALS method:

t = tpc()
Y = teneva.als(I_trn, y_trn, Y_anova, nswp)
t = tpc() - t

print(f'Build time     : {t:-10.2f}')

# >>> ----------------------------------------
# >>> Output:

# Build time     :      34.58
#
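For intuition about what one sweep does, here is a simplified NumPy sketch of TT-ALS on samples. It is an illustration under stated assumptions, not the teneva_jax implementation: cores are 3D arrays of shape [r_{k-1}, n_k, r_k]; for core k, the cores to its left and right give, for every sample, a left vector and a right vector, and because the model value is linear in G_k, each slice G_k[:, j, :] is found by ordinary least squares over the samples with i_k = j. The sketch omits the core orthogonalization that practical implementations typically use for numerical stability; als_sweep, tt_get_many and rel_err are hypothetical names.

```python
import numpy as np

def tt_get_many(cores, I):
    """Hypothetical helper: TT-tensor values at the rows of I ([samples, d])."""
    Q = np.ones((I.shape[0], 1))
    for k, G in enumerate(cores):
        # Contract the running row vectors with the selected core slices.
        Q = np.einsum('sp,psq->sq', Q, G[:, I[:, k], :])
    return Q[:, 0]

def als_sweep(cores, I, y):
    """Hypothetical sketch: one left-to-right ALS sweep, updating cores in place."""
    m, d = I.shape
    for k in range(d):
        # Left interface: product of cores 0..k-1 per sample, shape [m, r_{k-1}].
        L = np.ones((m, 1))
        for q in range(k):
            L = np.einsum('sp,psq->sq', L, cores[q][:, I[:, q], :])
        # Right interface: product of cores k+1..d-1 per sample, shape [m, r_k].
        R = np.ones((m, 1))
        for q in range(d - 1, k, -1):
            R = np.einsum('psq,sq->sp', cores[q][:, I[:, q], :], R)
        r1, n, r2 = cores[k].shape
        # y_s = L_s @ G_k[:, i_k, :] @ R_s is linear in G_k, with features kron(L_s, R_s).
        A = np.einsum('sp,sq->spq', L, R).reshape(m, r1 * r2)
        for j in range(n):
            mask = I[:, k] == j
            if not mask.any():
                continue  # no samples with i_k = j: keep the old slice
            g, *_ = np.linalg.lstsq(A[mask], y[mask], rcond=None)
            cores[k][:, j, :] = g.reshape(r1, r2)
    return cores

def rel_err(cores, I, y):
    """Relative Euclidean error of the TT-approximation on the samples."""
    return np.linalg.norm(tt_get_many(cores, I) - y) / np.linalg.norm(y)

# Synthetic data: samples of an exact rank-2 TT-tensor (d = 4, n = 5).
rng = np.random.default_rng(0)
d, n, r, m = 4, 5, 2, 500
ranks = [1] + [r] * (d - 1) + [1]
true = [rng.standard_normal((ranks[k], n, ranks[k + 1])) for k in range(d)]
I = rng.integers(0, n, size=(m, d))
y = tt_get_many(true, I)

# Random initial approximation with the same ranks, then a few sweeps.
Y = [rng.standard_normal(G.shape) for G in true]
e0 = rel_err(Y, I, y)
errors = []
for _ in range(10):
    als_sweep(Y, I, y)
    errors.append(rel_err(Y, I, y))
print(f'{e0:.2e} -> {errors[-1]:.2e}')
```

Since every slice update is an exact least-squares solve with the other cores fixed, the training error in this sketch cannot increase from one update to the next (up to rounding).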

We can check the accuracy of the result:

# Compute approximation in train points:
y_our = teneva.get_many(Y, I_trn)

# Accuracy of the result for train points:
e_trn = jnp.linalg.norm(y_our - y_trn)
e_trn /= jnp.linalg.norm(y_trn)

# Compute approximation in test points:
y_our = teneva.get_many(Y, I_tst)

# Accuracy of the result for test points:
e_tst = jnp.linalg.norm(y_our - y_tst)
e_tst /= jnp.linalg.norm(y_tst)

print(f'Error on train : {e_trn:-10.2e}')
print(f'Error on test  : {e_tst:-10.2e}')

# >>> ----------------------------------------
# >>> Output:

# Error on train :   7.16e-05
# Error on test  :   7.80e-05
#
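The printed errors are relative errors in the Euclidean norm, computed from the vector y of exact values and the vector ŷ of approximated values at the same multi-indices (the sums run over all points of the training or test set):

```latex
e = \frac{\lVert \hat{y} - y \rVert_2}{\lVert y \rVert_2}
  = \frac{\sqrt{\sum_{s} (\hat{y}_s - y_s)^2}}{\sqrt{\sum_{s} y_s^2}}
```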

We can compare the result with the base (numpy) ALS method, run on the same training data with the same initial approximation and parameters:

t = tpc()
Y = teneva_base.als(I_trn, y_trn, Y_anova_base, nswp, e=-1.)
t = tpc() - t

print(f'Build time     : {t:-10.2f}')

# Compute approximation in train points:
y_our = teneva_base.get_many(Y, I_trn)

# Accuracy of the result for train points:
e_trn = jnp.linalg.norm(y_our - y_trn)
e_trn /= jnp.linalg.norm(y_trn)

# Compute approximation in test points:
y_our = teneva_base.get_many(Y, I_tst)

# Accuracy of the result for test points:
e_tst = jnp.linalg.norm(y_our - y_tst)
e_tst /= jnp.linalg.norm(y_tst)

print(f'Error on train : {e_trn:-10.2e}')
print(f'Error on test  : {e_tst:-10.2e}')

# >>> ----------------------------------------
# >>> Output:

# Build time     :      86.45
# Error on train :   1.04e-03
# Error on test  :   1.18e-03
#