# Source code for teneva_jax.act_two

"""Package teneva_jax, module act_two: operations with a pair of TT-tensors.

This module contains the basic operations with a pair of TT-tensors (Y1, Y2),
including "add", "mul", "sub", etc.

"""
import jax
import jax.numpy as jnp
import teneva_jax as teneva


def accuracy(Y1, Y2):
    """Compute || Y1 - Y2 || / || Y2 || for tensors in the TT-format.

    Both norms are obtained from ``teneva.norm_stab`` as a scaled value ``z``
    and a vector of power-of-two factors ``p``. The ratio is therefore
    assembled as ``2^{sum(p1 - p2)} * z1 / z2``.

    Args:
        Y1 (list): TT-tensor.
        Y2 (list): TT-tensor.

    Returns:
        jnp.ndarray of size 1: the relative difference between two tensors.
        In degenerate cases a Python scalar sentinel is returned instead:
        ``1.E+299`` if the total exponent difference exceeds +500, ``0.`` if
        it is below -500, and ``-1`` if the scale factor or either norm is
        infinite, or if the norm of Y2 is (numerically) zero.

    """
    z1, p1 = teneva.norm_stab(sub(Y1, Y2))
    z2, p2 = teneva.norm_stab(Y2)

    # Total power-of-two exponent of the ratio; computed once and reused.
    dp = (p1 - p2).sum()

    if dp > +500:
        return 1.E+299
    if dp < -500:
        return 0.

    # NOTE(review): if dp is float32 (JAX's default precision), 2.**dp
    # overflows to inf already for dp >= 128, so such cases end up in the
    # -1 branch below rather than the 1.E+299 branch above.
    c = 2.**dp

    # If z2 is float32, the literal 1.E-100 is cast to float32 for the
    # comparison and flushes to 0.0, so ``abs(z2) < 1.E-100`` alone can never
    # be true. The explicit ``z2 == 0`` check keeps the zero-norm guard
    # effective in both float32 and float64 modes.
    if jnp.isinf(c) or jnp.isinf(z1) or jnp.isinf(z2) or z2 == 0 or abs(z2) < 1.E-100:
        return -1 # TODO: check

    return c * z1 / z2


def add(Y1, Y2):
    """Compute Y1 + Y2 in the TT-format.

    The cores of the sum are block-structured:
    - the left cores are stacked along their right rank;
    - the right cores are stacked along their left rank;
    - each middle core is block-diagonal in the two rank dimensions.

    Args:
        Y1 (list): TT-tensor ``[Yl, Ym, Yr]``, where ``Ym`` stacks the middle
            cores along its first axis.
        Y2 (list): TT-tensor in the same format and with the same mode sizes.

    Returns:
        list: TT-tensor, which represents the element wise sum of Y1 and Y2.
        Its ranks are the sums of the corresponding ranks of Y1 and Y2.

    """
    def body(q, data):
        G1, G2 = data
        r1_l, n, r1_r = G1.shape
        r2_l, _, r2_r = G2.shape

        # Pad with zeros of the cores' common dtype. Otherwise the float32
        # default of ``jnp.zeros`` would promote e.g. float16 cores, making
        # the middle cores' dtype differ from the left/right cores.
        dtype = jnp.result_type(G1, G2)
        Z1 = jnp.zeros([r1_l, n, r2_r], dtype=dtype)
        Z2 = jnp.zeros([r2_l, n, r1_r], dtype=dtype)

        # Assemble [[G1, 0], [0, G2]] over the (left rank, right rank) axes.
        L1 = jnp.concatenate([G1, Z1], axis=2)
        L2 = jnp.concatenate([Z2, G2], axis=2)
        G = jnp.concatenate([L1, L2], axis=0)
        return None, G

    Yl1, Ym1, Yr1 = Y1
    Yl2, Ym2, Yr2 = Y2

    Yl = jnp.concatenate([Yl1, Yl2], axis=2)
    _, Ym = jax.lax.scan(body, None, (Ym1, Ym2))
    Yr = jnp.concatenate([Yr1, Yr2], axis=0)

    return [Yl, Ym, Yr]


def mul(Y1, Y2):
    """Compute element wise product Y1 * Y2 in the TT-format.

    Every core of the result is built from the matching cores of Y1 and Y2.
    Along the shared mode index, their rank matrices are combined by a
    Kronecker product. Hence each rank of the result is the product of the
    corresponding ranks of the two factors.

    Args:
        Y1 (list): TT-tensor ``[Yl, Ym, Yr]``.
        Y2 (list): TT-tensor ``[Yl, Ym, Yr]``.

    Returns:
        list: TT-tensor, which represents the element wise product of Y1 and Y2.

    """
    def kron(A, B):
        # P[a, b, i, c, d] = A[a, i, c] * B[b, i, d]; then merge (a, b) and
        # (c, d) into the new left and right rank axes.
        P = A[:, None, :, :, None] * B[None, :, :, None, :]
        rank_l = A.shape[0] * B.shape[0]
        rank_r = A.shape[-1] * B.shape[-1]
        return P.reshape([rank_l, -1, rank_r])

    def step(carry, pair):
        return carry, kron(*pair)

    Al, Am, Ar = Y1
    Bl, Bm, Br = Y2

    left = kron(Al, Bl)
    _, middle = jax.lax.scan(step, None, (Am, Bm))
    right = kron(Ar, Br)

    return [left, middle, right]


def mul_scalar(Y1, Y2):
    """Compute scalar product for Y1 and Y2 in the TT-format.

    The cores are swept from left to right while a row vector ``q`` is
    carried. At each core, ``q`` is multiplied by the Kronecker product of the
    two cores, summed over the mode index.

    Args:
        Y1 (list): TT-tensor ``[Yl, Ym, Yr]``.
        Y2 (list): TT-tensor ``[Yl, Ym, Yr]``.

    Returns:
        jnp.ndarray of size 1: the scalar product.

    """
    def body(q, data):
        G1, G2 = data
        G = G1[:, None, :, :, None] * G2[None, :, :, None, :]
        G = G.reshape([G1.shape[0]*G2.shape[0], -1, G1.shape[-1]*G2.shape[-1]])
        G = jnp.sum(G, axis=1)
        q = q @ G
        # Only the carried vector is needed. Emitting G as the per-step
        # output would make ``jax.lax.scan`` stack all these matrices into
        # an array that is then thrown away.
        return q, None

    Yl1, Ym1, Yr1 = Y1
    Yl2, Ym2, Yr2 = Y2

    q, _ = body(jnp.ones(1), (Yl1, Yl2))
    q, _ = jax.lax.scan(body, q, (Ym1, Ym2))
    q, _ = body(q, (Yr1, Yr2))

    return q


def mul_scalar_stab(Y1, Y2):
    """Compute scalar product for Y1 and Y2 in the TT-format with stab. factor.

    This is the same left-to-right sweep as ``mul_scalar``, except that after
    every core the carried vector is rescaled by a power of two. The rescaling
    brings its largest magnitude into [1, 2) and avoids overflow/underflow.

    Args:
        Y1 (list): TT-tensor ``[Yl, Ym, Yr]``.
        Y2 (list): TT-tensor ``[Yl, Ym, Yr]``.

    Returns:
        (jnp.ndarray of size 1, jnp.ndarray): the scaled value of the scalar
        product and stabilization factor p for each TT-core (array of the
        length d). The resulting value is v * 2^{sum(p)}. If the scalar
        product is exactly zero, v is 0 and the factors are finite.

    """
    def body(q, data):
        G1, G2 = data
        G = G1[:, None, :, :, None] * G2[None, :, :, None, :]
        G = G.reshape([G1.shape[0]*G2.shape[0], -1, G1.shape[-1]*G2.shape[-1]])
        G = jnp.sum(G, axis=1)
        q = q @ G

        q_max = jnp.max(jnp.abs(q))
        # If q vanished entirely (e.g. orthogonal tensors), log2(0) = -inf
        # would turn q / 2^p into 0 / 0 = nan. Use the neutral factor 2^0
        # instead so that the result stays an exact zero.
        q_max = jnp.where(q_max > 0, q_max, 1.)
        p = jnp.floor(jnp.log2(q_max))
        q = q / 2.**p
        return q, p

    Yl1, Ym1, Yr1 = Y1
    Yl2, Ym2, Yr2 = Y2

    q, pl = body(jnp.ones(1), (Yl1, Yl2))
    q, pm = jax.lax.scan(body, q, (Ym1, Ym2))
    q, pr = body(q, (Yr1, Yr2))

    return q, jnp.hstack((pl, pm, pr))


def sub(Y1, Y2):
    """Compute Y1 - Y2 in the TT-format.

    The difference is formed as ``Y1 + (-Y2)``. Negating only the first
    (left) core of Y2 negates the whole tensor it represents.

    Args:
        Y1 (list): TT-tensor ``[Yl, Ym, Yr]``.
        Y2 (list): TT-tensor ``[Yl, Ym, Yr]``; it is not modified.

    Returns:
        list: TT-tensor, which represents the result of the operation Y1-Y2.

    """
    # ``Yl2 * -1.`` creates a new array, so Y2 is left untouched without first
    # copying the whole TT-tensor (the previous version called
    # ``teneva.copy``, presumably duplicating every core, to negate only one).
    Yl2, Ym2, Yr2 = Y2
    return add(Y1, [Yl2 * -1., Ym2, Yr2])