# Source code for teneva_jax.cross

"""Package teneva, module cross: construct TT-tensor, using TT-cross.

This module contains the function "cross", which computes the TT-approximation
of an implicit tensor given functionally, by the rank-adaptive multidimensional
cross approximation method in the TT-format (TT-cross).

"""
import jax
import jax.numpy as jnp
import teneva_jax as teneva


from functools import partial


def cross(f, Y0, nswp=10):
    """Compute the TT-approximation for implicit tensor given functionally.

    DRAFT: this implementation has not been validated against a reference
    TT-cross; treat its results with care.

    This function computes the TT-approximation for implicit tensor given
    functionally by the multidimensional cross approximation method in the
    TT-format (TT-cross). Note that the "f" function is expected to be jitted.

    Every multi-index set is stored with the full width "d" (the columns that
    do not belong to the set are zero). This keeps all shapes constant, so the
    middle TT-cores can be processed by "jax.lax.scan".

    Args:
        f (function): function f(I) which computes tensor elements for the
            given set of multi-indices I, where I is a 2D integer jnp.ndarray
            of the shape [samples, dimensions]. The function should return 1D
            jnp.ndarray of the length equals to samples, which relates to the
            values of the target function for all provided samples.
        Y0 (list): TT-tensor [Yl, Ym, Yr], which is the initial approximation
            for algorithm. NOTE(review): the code assumes that all TT-ranks
            are equal and do not exceed the mode size - confirm against the
            teneva_jax tensor format.
        nswp (int): maximum number of iterations (sweeps) of the algorithm.
            One sweep corresponds to a complete pass of all tensor TT-cores
            from left to right and then from right to left.

    Returns:
        list: TT-Tensor [Yl, Ym, Yr] which approximates the implicit tensor.

    Raises:
        ValueError: if the tensor has less than three dimensions.

    """
    Yl, Ym, Yr = teneva.copy(Y0)
    d = len(Ym) + 2
    n = Yl.shape[1]

    if d < 3:
        raise ValueError('TT-cross requires a tensor with at least 3 modes')

    ig = jnp.arange(n)              # All indices of one mode
    pos = jnp.arange(d)             # Numbers of the tensor modes
    pos_mid = jnp.arange(1, d - 1)  # Numbers of the modes of the middle cores
    edge = jnp.zeros((1, d), dtype=ig.dtype)  # "Empty" boundary index set

    def _onehot(k):
        # Row selector which places the mode index into the k-th column:
        return (pos == k).astype(ig.dtype)

    def _sample(Ir, Ic, k):
        # Values of f on the grid (rows Ir) x (k-th mode) x (columns Ic); the
        # samples are enumerated in the Fortran order of (r1, n, r2):
        r1, r2 = Ir.shape[0], Ic.shape[0]
        mode = jnp.repeat(jnp.tile(ig, r2), r1)
        I = jnp.tile(Ir, (n * r2, 1))
        I = I + mode[:, None] * _onehot(k)
        I = I + jnp.repeat(Ic, r1 * n, axis=0)
        return jnp.reshape(f(I), (r1, n, r2), order='F')

    def _select_ltr(Z, Ir, k):
        # Choose (by maxvol) new row multi-indices for the unfolding (r1*n, r2):
        r1, _, r2 = Z.shape
        cand = jnp.tile(Ir, (n, 1))
        cand = cand + jnp.repeat(ig, r1)[:, None] * _onehot(k)
        Q, _ = jnp.linalg.qr(jnp.reshape(Z, (r1 * n, r2), order='F'))
        ind, _ = teneva.maxvol(Q)
        return cand[ind, :]

    def _select_rtl(Z, Ic, k):
        # Choose (by maxvol) new column multi-indices for the unfolding
        # (r1, n*r2); return the new core, the R-factor and the indices:
        r1, _, r2 = Z.shape
        cand = jnp.repeat(Ic, n, axis=0)
        cand = cand + jnp.tile(ig, r2)[:, None] * _onehot(k)
        Q, R = jnp.linalg.qr(jnp.reshape(Z, (r1, n * r2), order='F').T)
        ind, B = teneva.maxvol(Q)
        G = jnp.reshape(B.T, (-1, n, r2), order='F')
        return G, (Q[ind, :] @ R).T, cand[ind, :]

    def _init_body(carry, xs):
        # Right-to-left orthogonalization step for the initial approximation:
        R, Ic = carry
        G, k = xs
        G, R, Ic = _select_rtl(jnp.tensordot(G, R, 1), Ic, k)
        return (R, Ic), (G, Ic)

    def _ltr_body(Ir, xs):
        # Only the row indices are kept: the cores themselves are rebuilt
        # during the following right-to-left pass.
        Ic, k = xs
        Ir = _select_ltr(_sample(Ir, Ic, k), Ir, k)
        return Ir, Ir

    def _rtl_body(Ic, xs):
        Ir, k = xs
        G, _, Ic = _select_rtl(_sample(Ir, Ic, k), Ic, k)
        return Ic, (G, Ic)

    @jax.jit
    def _sweep(Icl, Icm):
        # Left-to-right pass (builds the row index sets for all cores):
        Irl = _select_ltr(_sample(edge, Icl, 0), edge, 0)
        _, Irm = jax.lax.scan(_ltr_body, Irl, (Icm, pos_mid))
        Irm, Irr = _shift_ltr(Irl, Irm)

        # Right-to-left pass (builds the cores and the column index sets):
        Yr, _, Icr = _select_rtl(_sample(Irr, edge, d - 1), edge, d - 1)
        _, (Ym, Icm) = jax.lax.scan(
            _rtl_body, Icr, (Irm, pos_mid), reverse=True)
        Icl, Icm = _shift_rtl(Icm, Icr)
        Yl, R, _ = _select_rtl(_sample(edge, Icl, 0), edge, 0)

        # The remaining R-factor is absorbed by the first core:
        return jnp.tensordot(R, Yl, 1), Ym, Yr, Icl, Icm

    # Initial column index sets are generated from the cores of Y0:
    (R, Icr), (Yr, _) = _init_body((jnp.ones((1, 1)), edge), (Yr, d - 1))
    (R, _), (Ym, Icm) = jax.lax.scan(
        _init_body, (R, Icr), (Ym, pos_mid), reverse=True)
    Icl, Icm = _shift_rtl(Icm, Icr)
    (R, _), (Yl, _) = _init_body((R, Icl), (Yl, 0))
    Yl = jnp.tensordot(R, Yl, 1)

    for _ in range(nswp):
        Yl, Ym, Yr, Icl, Icm = _sweep(Icl, Icm)

    return [Yl, Ym, Yr]
def _iter_ltr(Z, Ir):
    """Process one TT-core while moving from left to right.

    Args:
        Z (jnp.ndarray): 3D array of the shape [r1, n, r2].
        Ir (jnp.ndarray): row multi-indices (2D array with r1 rows) which
            relate to the first axis of Z.

    Returns:
        tuple: the new TT-core G of the shape [r1, n, r], the matrix R (the
        selected rows of the QR-factor Q multiplied by the QR-factor R) and
        the selected row multi-indices (the rows of Ir extended by the index
        of the current mode).

    """
    r1, n, r2 = Z.shape

    # Candidates: every row of Ir combined with every index of the mode (the
    # ordering matches the Fortran-ordered unfolding of Z used below):
    I = jnp.kron(jnp.arange(n), jnp.ones(r1)).reshape((-1, 1))
    I = jnp.hstack((jnp.kron(jnp.ones((n, 1)), Ir), I))

    Q, R = jnp.linalg.qr(jnp.reshape(Z, (r1 * n, r2), order='F'))
    ind, B = teneva.maxvol(Q)
    G = jnp.reshape(B, (r1, n, -1), order='F')
    R = Q[ind, :] @ R

    return G, R, I[ind, :]


def _iter_rtl(Z, Il):
    """Process one TT-core while moving from right to left.

    Args:
        Z (jnp.ndarray): 3D array of the shape [r1, n, r2].
        Il (jnp.ndarray): column multi-indices (2D array with r2 rows) which
            relate to the last axis of Z.

    Returns:
        tuple: the new TT-core G of the shape [r, n, r2], the (transposed)
        matrix R and the selected column multi-indices (the rows of Il
        prepended by the index of the current mode).

    """
    r1, n, r2 = Z.shape

    # Candidates: every index of the mode combined with every row of Il (the
    # ordering matches the transposed Fortran-ordered unfolding of Z):
    I = jnp.kron(jnp.ones(r2), jnp.arange(n)).reshape((-1, 1))
    I = jnp.hstack((I, jnp.kron(Il, jnp.ones((n, 1)))))

    Q, R = jnp.linalg.qr(jnp.reshape(Z, (r1, n * r2), order='F').T)
    ind, B = teneva.maxvol(Q)
    G = jnp.reshape(B.T, (-1, n, r2), order='F')
    R = (Q[ind, :] @ R).T

    return G, R, I[ind, :]


@jax.jit
def _shift_ltr(Zl_ltr, Zm_ltr):
    """Shift the stacked left-to-right results by one position to the right.

    Returns the stack "Zl_ltr, Zm_ltr[0], ..., Zm_ltr[-2]" and the last item
    "Zm_ltr[-1]" separately.

    """
    return jnp.vstack((Zl_ltr[None], Zm_ltr[:-1])), Zm_ltr[-1]


@jax.jit
def _shift_rtl(Zm_rtl, Zr_rtl):
    """Shift the stacked right-to-left results by one position to the left.

    Returns the first item "Zm_rtl[0]" separately and the stack
    "Zm_rtl[1], ..., Zm_rtl[-1], Zr_rtl".

    """
    return Zm_rtl[0], jnp.vstack((Zm_rtl[1:], Zr_rtl[None]))


def cross_1(f, Y0, nswp=10):
    """Compute the TT-approximation by TT-cross with plain python loops.

    This variant handles the TT-cores one by one in python loops (without
    "jax.lax.scan"), hence the multi-index sets may have different widths.

    Args:
        f (function): function f(I) which computes tensor elements for the
            given set of multi-indices I, where I is a 2D jnp.ndarray of the
            shape [samples, dimensions]. The function should return 1D
            jnp.ndarray of the length equals to samples.
        Y0 (list): TT-tensor, which is the initial approximation.
        nswp (int): number of iterations (sweeps) of the algorithm. One sweep
            is a pass of all TT-cores from left to right and then from right
            to left.

    Returns:
        list: the result of "teneva.convert" applied to the list of the
        computed TT-cores (presumably the TT-tensor in the format of Y0;
        verify against "teneva.convert").

    """
    import numpy as onp

    Y = teneva.copy(Y0)
    d = len(Y[1]) + 2
    n = Y[0].shape[1]

    # Ir[i] / Ic[i] are the row / column multi-indices related to the i-th
    # core; the boundary sets are empty (one row and zero columns):
    Ir = [jnp.zeros((1, 0)) for _ in range(d + 1)]
    Ic = [jnp.zeros((1, 0)) for _ in range(d + 1)]

    Y = teneva.convert(Y)

    # Initial column multi-indices are generated from the cores of Y0:
    R = jnp.ones((1, 1))
    for i in range(d - 1, -1, -1):
        G = jnp.tensordot(Y[i], R, 1)
        Y[i], R, Ic[i] = _iter_rtl(G, Ic[i + 1])
    Y[0] = jnp.tensordot(R, Y[0], 1)

    def _func(n, Ir, Ic):
        # Values of f on the grid (rows Ir) x (current mode) x (columns Ic):
        r1, r2 = Ir.shape[0], Ic.shape[0]
        I = jnp.kron(
            jnp.kron(jnp.ones(r2), jnp.arange(n)), jnp.ones(r1))
        I = I.reshape((-1, 1))
        I = jnp.hstack((jnp.kron(jnp.ones((n * r2, 1)), Ir), I))
        I = jnp.hstack((I, jnp.kron(Ic, jnp.ones((r1 * n, 1)))))
        return jnp.reshape(f(I), (r1, n, r2), order='F')

    def _iter_ltr_body(Ir, Ic):
        Z = _func(n, Ir, Ic)
        G, R, Ir = _iter_ltr(Z, Ir)
        return G, R, Ir

    def _iter_rtl_body(Ir, Ic):
        Z = _func(n, Ir, Ic)
        G, R, Ic = _iter_rtl(Z, Ic)
        return G, R, Ic

    for _ in range(nswp):
        # Pass from left to right:
        Y[0], R, Ir[1] = _iter_ltr_body(jnp.zeros((1, 0)), Ic[1])
        for i in range(1, d - 1):
            Y[i], R, Ir[i + 1] = _iter_ltr_body(Ir[i], Ic[i + 1])
        Y[d - 1], R, Ir[d] = _iter_ltr_body(Ir[d - 1], jnp.zeros((1, 0)))
        Y[d - 1] = jnp.tensordot(Y[d - 1], R, 1)

        # Pass from right to left:
        Y[d - 1], R, Ic[d - 1] = _iter_rtl_body(
            Ir[d - 1], jnp.zeros((1, 0)))
        for i in range(d - 2, 0, -1):
            Y[i], R, Ic[i] = _iter_rtl_body(Ir[i], Ic[i + 1])
        Y[0], R, Ic[0] = _iter_rtl_body(jnp.zeros((1, 0)), Ic[1])
        Y[0] = jnp.tensordot(R, Y[0], 1)

    Y = [onp.array(G) for G in Y]
    return teneva.convert(Y)


def _demo():
    """Approximate the Rosenbrock function by TT-cross and print the error."""
    from time import perf_counter as tpc

    # NOTE(review): the original demo re-imported "teneva.core_jax" under the
    # name "teneva"; the module-level "teneva_jax" import is used instead -
    # confirm that this is the intended package.

    # Double precision should be enabled before any jax computations:
    jax.config.update('jax_enable_x64', True)

    rng = jax.random.PRNGKey(42)

    d = 10              # Dimension of the function
    n = 5               # Shape of the tensor
    r = 3               # TT-rank of the initial random tensor
    nswp = 5            # Sweep number for TT-cross iterations
    m_tst = int(1.E+4)  # Number of test points
    a = -2.048          # Lower bound for the spatial grid
    b = +2.048          # Upper bound for the spatial grid

    def func_base(i):
        """Rosenbrock function (for one multi-index of the grid)."""
        x = i / n * (b - a) + a
        y1 = 100. * (x[1:] - x[:-1]**2)**2
        y2 = (x[:-1] - 1.)**2
        return jnp.sum(y1 + y2)

    func = jax.jit(jax.vmap(func_base))

    rng, key = jax.random.split(rng)
    I_tst = teneva.sample_rand(d, n, m_tst, key)
    y_tst = func(I_tst)

    rng, key = jax.random.split(rng)
    Y0 = teneva.rand(d, n, r, key)

    t = tpc()
    Y = cross(func, Y0, nswp)
    t = tpc() - t

    print(f'Build time     : {t:-10.2f}')

    # Compute approximation in test points:
    y_our = teneva.get_many(Y, I_tst)

    # Accuracy of the result for test points:
    e_tst = jnp.linalg.norm(y_our - y_tst) / jnp.linalg.norm(y_tst)

    print(f'Error on test  : {e_tst:-10.2e}')


if __name__ == '__main__':
    _demo()