Module als: construct TT-tensor by TT-ALS
Package teneva_jax, module als: construct a TT-tensor using TT-ALS.
This module contains the function “als”, which computes the TT-approximation of a tensor by the TT-ALS algorithm from given random samples (i.e., a set of random tensor multi-indices and the corresponding tensor values).
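To illustrate the idea of alternating least squares on samples, here is a minimal sketch for the simplest case d = 2, where a TT-tensor reduces to a low-rank matrix factorization. It is not the teneva_jax implementation: the names fit_rows and als_matrix, the random initial factors and the toy data are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def fit_rows(F, idx, y, old):
    """Least-squares update of every row of one factor.

    Row i is fitted so that F[k] @ row_i ~ y[k] for all samples k
    whose index idx[k] equals i.
    """
    rows = []
    for i in range(old.shape[0]):
        mask = idx == i
        if not bool(mask.any()):
            rows.append(old[i])  # no samples for this index: keep the old row
            continue
        rows.append(jnp.linalg.lstsq(F[mask], y[mask])[0])
    return jnp.stack(rows)


def als_matrix(I, y, U, V, nswp=10):
    """ALS for A[i, j] ~ U[i] @ V[j] from samples y[k] = A[I[k, 0], I[k, 1]]."""
    for _ in range(nswp):
        U = fit_rows(V[I[:, 1]], I[:, 0], y, U)  # fix V, update U
        V = fit_rows(U[I[:, 0]], I[:, 1], y, V)  # fix U, update V
    return U, V


# Usage: try to recover a random rank-2 matrix from 48 of its 64 entries
n, r, m = 8, 2, 48
k_a, k_b, k_s, k_u, k_v = jax.random.split(jax.random.PRNGKey(0), 5)
A = jax.random.normal(k_a, (n, r)) @ jax.random.normal(k_b, (r, n))
flat = jax.random.permutation(k_s, n * n)[:m]      # sampled positions
I_trn = jnp.stack([flat // n, flat % n], axis=1)   # shape [samples, 2]
y_trn = A[I_trn[:, 0], I_trn[:, 1]]                # shape [samples]
U, V = als_matrix(I_trn, y_trn,
                  jax.random.normal(k_u, (n, r)),
                  jax.random.normal(k_v, (n, r)))
err = jnp.linalg.norm(U @ V.T - A) / jnp.linalg.norm(A)
print(f'Relative error on the full matrix: {float(err):.2e}')
```

Each half-step is an ordinary linear least-squares problem because one factor is held fixed; TT-ALS applies the same idea to one TT-core at a time while sweeping over the dimensions.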
- teneva_jax.als.als(I_trn, y_trn, Y0, nswp=50)
Build a TT-tensor by the TT-ALS method from given random tensor samples.
Note that this function makes inner jax.jit calls, so wrapping the function itself in jax.jit is not recommended.
- Parameters:
I_trn (jnp.ndarray) – multi-indices for the tensor as an array of shape [samples, d], where d is the number of tensor dimensions and samples is the size of the training dataset.
y_trn (jnp.ndarray) – values of the tensor at the multi-indices I_trn as an array of shape [samples].
Y0 (list) – TT-tensor used as the initial approximation for the algorithm.
nswp (int) – number of ALS iterations (sweeps).
- Returns:
TT-tensor, which represents the TT-approximation for the tensor.
- Return type:
list
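The note on inner jax.jit calls can be illustrated with a minimal sketch of the general pattern: a plain Python driver loop that calls a small compiled step. The names sweep_step and solve are hypothetical and the update rule is a toy one; this does not show how als is organized internally.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sweep_step(x, target):
    """One compiled update: move x halfway towards target."""
    return x + 0.5 * (target - x)


def solve(x0, target, nswp=50):
    """Plain-Python driver: the loop stays in Python, each step is compiled."""
    x = x0
    for _ in range(nswp):
        x = sweep_step(x, target)
    return x


# Call the driver directly; sweep_step is compiled once and then reused
x = solve(jnp.zeros(3), jnp.ones(3), nswp=20)
```

Wrapping solve in jax.jit would trace the Python loop and unroll all nswp steps into a single program, which typically increases compilation time without any benefit for a driver of this kind.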
Examples:
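    import jax                            # assumed import
    import jax.numpy as jnp               # assumed import
    import teneva as teneva_base          # assumed alias: base (numpy) package
    import teneva_jax as teneva           # assumed alias: jax package
    from time import perf_counter as tpc  # assumed import
    rng = jax.random.PRNGKey(42)          # assumed seed (arbitrary value)
    d = 20              # Dimension of the function
    n = 10              # Shape of the tensor
    r = 5               # TT-rank of the initial random tensor
    nswp = 50           # Sweep number for ALS iterations
    m = int(1.E+5)      # Number of calls to target function
    m_tst = int(1.E+4)  # Number of test points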
We set the target function (the function takes as input a multi-index i of the shape [dimension], which is transformed into point x of a uniform spatial grid):
    a = -2.048  # Lower bound for the spatial grid
    b = +2.048  # Upper bound for the spatial grid

    def func_base(i):
        """Rosenbrock function."""
        x = i / n * (b - a) + a
        y1 = 100. * (x[1:] - x[:-1]**2)**2
        y2 = (x[:-1] - 1.)**2
        return jnp.sum(y1 + y2)

    func = jax.vmap(func_base)
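To see what the index-to-grid transformation does, we can map a few multi-indices by hand. This is a small sketch with a 3-dimensional toy input; the helper names index_to_point and target are ours, and target repeats the expression that func_base returns.

```python
import jax
import jax.numpy as jnp

n = 10                  # Number of grid nodes per dimension
a, b = -2.048, +2.048   # Bounds of the spatial grid


def index_to_point(i):
    """Map grid indices 0, ..., n-1 to points with step (b - a) / n."""
    return i / n * (b - a) + a


def target(x):
    """The sum returned by func_base, written for a point x."""
    return jnp.sum(100. * (x[1:] - x[:-1]**2)**2 + (x[:-1] - 1.)**2)


# vmap turns the single-index function into a batched one
func = jax.vmap(lambda i: target(index_to_point(i)))

I = jnp.array([[0, 0, 0], [9, 9, 9], [5, 5, 5]])  # shape [samples, d]
X = index_to_point(I)   # index 0 -> a; index 5 -> 0; index n-1 -> b - (b - a) / n
y = func(I)             # shape [samples]; the multi-index [5, 5, 5] gives 2
```

Note that with this formula the largest index n - 1 maps to b - (b - a) / n, so the upper bound b itself is never sampled.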
We prepare train data from the LHS random distribution:
    rng, key = jax.random.split(rng)
    I_trn = teneva.sample_lhs(d, n, m, key)
    y_trn = func(I_trn)
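For intuition, Latin hypercube sampling of multi-indices can be sketched as follows: every dimension receives a balanced sequence of indices, shuffled independently of the other dimensions. The function lhs_indices is a hypothetical illustration and is not claimed to match the implementation of teneva.sample_lhs.

```python
import jax
import jax.numpy as jnp


def lhs_indices(d, n, m, key):
    """Draw m multi-indices for a tensor of shape [n] * d.

    Each column holds the sequence 0, 1, ..., n-1, 0, 1, ... of length m
    in an independent random order, so every index value occurs
    (almost) equally often in every dimension.
    """
    base = jnp.arange(m) % n
    keys = jax.random.split(key, d)
    cols = [jax.random.permutation(k, base) for k in keys]
    return jnp.stack(cols, axis=1)  # shape [m, d]


I = lhs_indices(d=3, n=5, m=20, key=jax.random.PRNGKey(0))

# Occurrences of each index value per dimension (all equal to m / n = 4 here)
counts = jnp.stack([jnp.bincount(I[:, k], length=5) for k in range(3)])
```

Compared with fully independent random indices, this guarantees that no index value of any dimension is missing from the training set when m >= n.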
We prepare test data from random tensor multi-indices:
    rng, key = jax.random.split(rng)
    I_tst = teneva.sample_rand(d, n, m_tst, key)
    y_tst = func(I_tst)
We build the initial approximation by the TT-ANOVA method:
    # TODO: replace with jax-version!
    Y_anova_base = teneva_base.anova(I_trn, y_trn, r)
    Y_anova = teneva.convert(Y_anova_base)
Now we build the TT-tensor that approximates the target function, using the TT-ALS method:
    t = tpc()
    Y = teneva.als(I_trn, y_trn, Y_anova, nswp)
    t = tpc() - t

    print(f'Build time : {t:-10.2f}')

    # >>> ----------------------------------------
    # >>> Output:

    # Build time : 34.58
    #
We can check the accuracy of the result:
    # Compute approximation in train points:
    y_our = teneva.get_many(Y, I_trn)

    # Accuracy of the result for train points:
    e_trn = jnp.linalg.norm(y_our - y_trn)
    e_trn /= jnp.linalg.norm(y_trn)

    # Compute approximation in test points:
    y_our = teneva.get_many(Y, I_tst)

    # Accuracy of the result for test points:
    e_tst = jnp.linalg.norm(y_our - y_tst)
    e_tst /= jnp.linalg.norm(y_tst)

    print(f'Error on train : {e_trn:-10.2e}')
    print(f'Error on test : {e_tst:-10.2e}')

    # >>> ----------------------------------------
    # >>> Output:

    # Error on train : 7.16e-05
    # Error on test : 7.80e-05
    #
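Evaluating a TT-tensor at a batch of multi-indices amounts to multiplying, for each multi-index, one matrix slice per TT-core. The sketch below assumes the common convention of a list of d three-dimensional cores of shape [r_left, n, r_right]; the helper tt_get is hypothetical, and the storage layout and the implementation of teneva.get_many in teneva_jax may differ.

```python
import jax
import jax.numpy as jnp


def tt_get(cores, i):
    """Value of a TT-tensor at one multi-index i (array of shape [d])."""
    v = jnp.ones((1,))
    for G, ik in zip(cores, i):
        v = v @ G[:, ik, :]  # [r_left] @ [r_left, r_right] -> [r_right]
    return v[0]


# Random TT-tensor with d = 4 dimensions, mode size n = 3 and TT-rank r = 2
d, n, r = 4, 3, 2
ranks = [1] + [r] * (d - 1) + [1]
keys = jax.random.split(jax.random.PRNGKey(0), d)
cores = [jax.random.normal(k, (ranks[j], n, ranks[j + 1]))
         for j, k in enumerate(keys)]

I = jnp.array([[0, 1, 2, 0], [2, 2, 1, 0]])   # shape [samples, d]
y = jax.vmap(lambda i: tt_get(cores, i))(I)   # shape [samples]
```

The cost per sample grows linearly with d, which is why the approximation can be checked on many points even though the full tensor has n**d entries.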
We can compare the result with the base (numpy) ALS method (we run it on the same train data with the same initial approximation and parameters):
    t = tpc()
    Y = teneva_base.als(I_trn, y_trn, Y_anova_base, nswp, e=-1.)
    t = tpc() - t

    print(f'Build time : {t:-10.2f}')

    # Compute approximation in train points:
    y_our = teneva_base.get_many(Y, I_trn)

    # Accuracy of the result for train points:
    e_trn = jnp.linalg.norm(y_our - y_trn)
    e_trn /= jnp.linalg.norm(y_trn)

    # Compute approximation in test points:
    y_our = teneva_base.get_many(Y, I_tst)

    # Accuracy of the result for test points:
    e_tst = jnp.linalg.norm(y_our - y_tst)
    e_tst /= jnp.linalg.norm(y_tst)

    print(f'Error on train : {e_trn:-10.2e}')
    print(f'Error on test : {e_tst:-10.2e}')

    # >>> ----------------------------------------
    # >>> Output:

    # Build time : 86.45
    # Error on train : 1.04e-03
    # Error on test : 1.18e-03
    #
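The accuracy measure used in these checks is the relative error in the Euclidean norm, i.e. the norm of the difference divided by the norm of the reference values. Since the same lines are repeated for train and test points, they could be factored into a small helper; rel_error is a hypothetical name and not part of teneva.

```python
import jax.numpy as jnp


def rel_error(y_our, y_ref):
    """Relative error ||y_our - y_ref|| / ||y_ref|| in the Euclidean norm."""
    return jnp.linalg.norm(y_our - y_ref) / jnp.linalg.norm(y_ref)


# Toy check: the difference has norm 0.5 and the reference has norm 5
e = rel_error(jnp.array([3.0, 4.5]), jnp.array([3.0, 4.0]))
print(f'Relative error : {float(e):.2e}')  # prints 1.00e-01
```

With such a helper, the two checks reduce to rel_error(y_our, y_trn) and rel_error(y_our, y_tst).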