"""Package teneva, module als: construct TT-tensor, using TT-ALS.
This module contains the function "als", which computes the TT-approximation
of a tensor with the TT-ALS algorithm from given random samples (i.e., a set
of random tensor multi-indices and the related tensor values).
"""
import jax
import jax.numpy as jnp
import teneva_jax as teneva
def als(I_trn, y_trn, Y0, nswp=50):
    """Build TT-tensor by TT-ALS method using given random tensor samples.

    Note that this function already calls jax.jit-compiled helpers internally.
    Do not wrap it in "jax.jit" yourself: the input validation below inspects
    concrete index values and raises a Python exception.

    Args:
        I_trn (jnp.ndarray): multi-indices for the tensor in the form of array
            of the shape [samples, d], where d is a number of tensor's
            dimensions and samples is a size of the train dataset.
        y_trn (jnp.ndarray): values of the tensor for multi-indices I_trn in
            the form of array of the shape [samples].
        Y0 (list): TT-tensor, which is the initial approximation for algorithm.
            The mode size n is read from Y0[0].shape[1], and the items
            Y0[0], Y0[1] and Y0[2] are used as the first core, the middle
            cores and the last core. The sweeps run on teneva.copy(Y0).
        nswp (int): number of ALS iterations (sweeps).

    Returns:
        list: TT-tensor, which represents the TT-approximation for the tensor.

    Raises:
        ValueError: if, for some dimension, the train multi-indices do not
            contain exactly n distinct values, or if any of them lies outside
            of the range [0, n).

    """
    n = Y0[0].shape[1]
    _, d = I_trn.shape

    # All helpers work with the transposed layout [d, samples]:
    I_trn = I_trn.T

    for k in range(d):
        seen = jnp.unique(I_trn[k, :])
        if seen.size != n:
            raise ValueError('One groundtruth sample is needed for every slice')
        # The count alone is not sufficient: n distinct values that are not
        # exactly 0, ..., n-1 would leave some slice without any sample.
        if bool(jnp.any((seen < 0) | (seen >= n))):
            raise ValueError(
                f'Multi-indices must lie in the range [0, {n}) '
                'for every dimension')

    inds = _build_indices(I_trn, n)

    Y = teneva.copy(Y0)

    # Right-to-left interface matrices for the initial approximation:
    Z_rtl = _iter_rtl_pre(Y[1], Y[2], I_trn)

    # The loop state is (Y, Z_rtl, I_trn, y_trn, inds); only Y is returned:
    return jax.lax.fori_loop(0, nswp, _iter, (Y, Z_rtl, I_trn, y_trn, inds))[0]
def _build_indices(I_trn, n):
    """Precompute the sample positions for each dimension and mode index.

    Args:
        I_trn (jnp.ndarray): multi-indices; the first axis is iterated as the
            dimension and the second one is searched for the mode indices.
        n (int): number of mode indices (0, ..., n-1) to look up.

    Returns:
        jnp.ndarray: array whose item [k, j] holds the positions of samples
        with I_trn[k, :] == j. Every group is cut to the length of the
        shortest one, so the result is rectangular.

    """
    d = I_trn.shape[0]

    groups = [
        [jnp.where(I_trn[k, :] == j)[0] for j in range(n)]
        for k in range(d)
    ]

    # Groups have different sizes in general; keep a common prefix length:
    shortest = min(len(group) for row in groups for group in row)

    return jnp.array([[group[:shortest] for group in row] for row in groups])
@jax.jit
def _iter(swp, data):
    """Perform one ALS sweep (left-to-right, then right-to-left).

    The signature matches the body function of jax.lax.fori_loop: "swp" is
    the loop counter (not used) and "data" is the loop state
    (Y, Z_rtl, I_trn, y_trn, inds), which is returned in the same layout.

    """
    Y, Z_rtl, I_trn, y_trn, inds = data

    # Forward half-sweep updates the first and the middle cores:
    (first, middle), Z_ltr = _iter_ltr(Y[0], Y[1], Z_rtl, I_trn, y_trn, inds)

    # Backward half-sweep starts from the just updated middle cores:
    (middle, last), Z_rtl = _iter_rtl(middle, Y[2], Z_ltr, I_trn, y_trn, inds)

    return [first, middle, last], Z_rtl, I_trn, y_trn, inds
@jax.jit
def _iter_ltr(Yl, Ym, Z_rtl, I_trn, y_trn, inds):
    """Run the left-to-right half-sweep over the first and middle cores.

    The first core is processed by a direct call of "_body_ltr", the middle
    cores are processed by jax.lax.scan over their leading axis.

    Returns:
        tuple: the pair (Yl, Ym) of updated cores and the left-to-right
        interface matrices rearranged by "_shift_z_ltr".

    """
    d, m = I_trn.shape

    # First core; the left interface is a column of ones:
    start = jnp.ones((m, 1))
    _, (Yl, Zl_ltr) = _body_ltr(
        start, (Yl, Z_rtl[0], I_trn[0], y_trn, inds[0]))

    # Middle cores; every scan step receives its own copy of the values:
    ym_trn = jnp.repeat(y_trn[:, None], d-2, axis=1).T
    xs = (Ym, Z_rtl[1], I_trn[1:-1], ym_trn, inds[1:-1])
    _, (Ym, Zm_ltr) = jax.lax.scan(_body_ltr, Zl_ltr, xs)

    return (Yl, Ym), _shift_z_ltr(Zl_ltr, Zm_ltr)
@jax.jit
def _iter_rtl(Ym, Yr, Z_ltr, I_trn, y_trn, inds):
    """Run the right-to-left half-sweep over the last and middle cores.

    The last core is processed by a direct call of "_body_rtl", the middle
    cores are processed by a reversed jax.lax.scan over their leading axis.

    Returns:
        tuple: the pair (Ym, Yr) of updated cores and the right-to-left
        interface matrices rearranged by "_shift_z_rtl".

    """
    d, m = I_trn.shape

    # Last core; the right interface is a row of ones:
    start = jnp.ones((1, m))
    _, (Yr, Zr_rtl) = _body_rtl(
        start, (Yr, Z_ltr[1], I_trn[-1], y_trn, inds[-1]))

    # Middle cores (in the reversed order); every step gets the same values:
    ym_trn = jnp.repeat(y_trn[:, None], d-2, axis=1).T
    xs = (Ym, Z_ltr[0], I_trn[1:-1], ym_trn, inds[1:-1])
    _, (Ym, Zm_rtl) = jax.lax.scan(_body_rtl, Zr_rtl, xs, reverse=True)

    return (Ym, Yr), _shift_z_rtl(Zm_rtl, Zr_rtl)
@jax.jit
def _iter_rtl_pre(Ym, Yr, I_trn):
    """Compute the right-to-left interface matrices without core updates.

    The pass mirrors "_iter_rtl" (last core first, then a reversed scan over
    the middle cores), but uses "_body_rtl_pre", which only contracts cores.

    Returns:
        tuple: the pair (Zl_rtl, Zm_rtl) produced by "_shift_z_rtl".

    """
    m = I_trn.shape[1]

    # Last core; the right interface is a row of ones:
    _, Zr_rtl = _body_rtl_pre(jnp.ones((1, m)), (Yr, I_trn[-1]))

    # Middle cores (in the reversed order):
    _, Zm_rtl = jax.lax.scan(
        _body_rtl_pre, Zr_rtl, (Ym, I_trn[1:-1]), reverse=True)

    return tuple(_shift_z_rtl(Zm_rtl, Zr_rtl))
@jax.jit
def _body_ltr(Z_ltr, data):
    """Update one core and move the left interface matrix one step right.

    The signature matches the body function of jax.lax.scan: "Z_ltr" is the
    carry and "data" is the tuple (G, Z_rtl, i, y, inds) for the current core.

    Returns:
        tuple: the new carry and the pair (updated core, new carry).

    """
    core, Z_rtl, i, y, inds = data

    # "_optimize" is scanned over the leading axis, hence the mode axis is
    # moved to the front before the scan and is moved back after it:
    _, slices = jax.lax.scan(
        _optimize, (Z_ltr, Z_rtl, y), (core.swapaxes(0, 1), inds))
    core = slices.swapaxes(0, 1)

    # For every sample, multiply the interface row by the selected core slice:
    Z_next = jnp.einsum('mq,qmr->mr', Z_ltr, core[:, i, :])

    return Z_next, (core, Z_next)
@jax.jit
def _body_rtl(Z_rtl, data):
    """Update one core and move the right interface matrix one step left.

    The signature matches the body function of jax.lax.scan: "Z_rtl" is the
    carry and "data" is the tuple (G, Z_ltr, i, y, inds) for the current core.

    Returns:
        tuple: the new carry and the pair (updated core, new carry).

    """
    core, Z_ltr, i, y, inds = data

    # "_optimize" is scanned over the leading axis, hence the mode axis is
    # moved to the front before the scan and is moved back after it:
    _, slices = jax.lax.scan(
        _optimize, (Z_ltr, Z_rtl, y), (core.swapaxes(0, 1), inds))
    core = slices.swapaxes(0, 1)

    # For every sample, multiply the selected core slice by the interface
    # column:
    Z_next = jnp.einsum('rmq,qm->rm', core[:, i, :], Z_rtl)

    return Z_next, (core, Z_next)
@jax.jit
def _body_rtl_pre(Z_rtl, data):
    """Move the right interface matrix one step left (no core update).

    The signature matches the body function of jax.lax.scan: "Z_rtl" is the
    carry and "data" is the pair (G, i) of the core and of the mode indices
    of the samples for the related dimension.

    Returns:
        tuple: the new interface matrix twice (as the carry and as the
        per-step output).

    """
    core, i = data

    # Core slices selected by the sample indices (axes are 'rmq'):
    picked = core[:, i, :]
    Z_next = jnp.einsum('rmq,qm->rm', picked, Z_rtl)

    return Z_next, Z_next
@jax.jit
def _optimize(args, data):
    """Fit one slice of the core by the regularized least squares.

    The signature matches the body function of jax.lax.scan: "args" is the
    carry (Z_ltr, Z_rtl, y), which is returned as is, and "data" is the pair
    (Q, idx) of the current slice and of the positions of the samples that
    are used to fit it. Only the shape of Q is used.

    Returns:
        tuple: the unchanged carry and the new slice of the same shape as Q.

    """
    Z_ltr, Z_rtl, y = args
    Q, idx = data

    # Interface values for the selected samples, shaped for broadcasting:
    right = Z_rtl[:, idx].T[:, jnp.newaxis, :]
    left = Z_ltr[idx, :][:, :, jnp.newaxis]

    # One row per sample, one column per item of the slice:
    A = (right * left).reshape(len(idx), -1)
    b = y[idx]

    # Normal equations with a small ridge term on the diagonal:
    lamb = 0.001
    gram = A.T @ A + lamb * jnp.identity(A.shape[1])
    moment = A.T @ b
    sol = jnp.linalg.lstsq(gram, moment)[0]

    return (Z_ltr, Z_rtl, y), sol.reshape(Q.shape)
@jax.jit
def _shift_z_ltr(Zl_ltr, Zm_ltr):
    """Rearrange the left-to-right interface matrices by one position.

    Returns:
        tuple: the stack of "Zl_ltr" and of all items of "Zm_ltr" except
        the last one, and the last item of "Zm_ltr".

    """
    last = Zm_ltr[-1]
    shifted = jnp.vstack((Zl_ltr[None], Zm_ltr[:-1]))
    return shifted, last
@jax.jit
def _shift_z_rtl(Zm_rtl, Zr_rtl):
    """Rearrange the right-to-left interface matrices by one position.

    Returns:
        tuple: the first item of "Zm_rtl", and the stack of all other items
        of "Zm_rtl" followed by "Zr_rtl".

    """
    first = Zm_rtl[0]
    shifted = jnp.vstack((Zm_rtl[1:], Zr_rtl[None]))
    return first, shifted