"""Package teneva, module act_one: single TT-tensor operations.
This module contains the basic operations with one TT-tensor (Y), including
"copy", "get", "sum", etc.
"""
import jax
import jax.numpy as jnp
try:
import numpy as np
WITH_NUMPY = True
except Exception as e:
WITH_NUMPY = False
import teneva_jax as teneva
def convert(Y):
    """Convert TT-tensor from base (numpy) format and back.

    The direction of the conversion is selected by the type of the first
    item of Y: if it is not a jax array, then Y is treated as a TT-tensor in
    the numpy format, otherwise it is treated as a TT-tensor in jax format.

    Args:
        Y (list): TT-tensor in numpy format (a list of d ordinary numpy arrays)
            or in jax format (a list of 3 jax.numpy arrays).

    Returns:
        list: TT-tensor in numpy format if Y is in jax format and vice versa.

    Raises:
        ValueError: if Y is in jax format and numpy is not available.

    """
    if not isinstance(Y[0], jnp.ndarray):  # Ordinary numpy format -> jax
        Yl = jnp.array(Y[0], copy=True)
        # All inner cores are stacked along a new leading axis:
        Ym = jnp.array(Y[1:-1], copy=True)
        Yr = jnp.array(Y[-1], copy=True)
        return [Yl, Ym, Yr]

    # Jax format -> ordinary numpy format
    if not WITH_NUMPY:
        raise ValueError('Numpy is required for this function')

    Yl, Ym, Yr = Y

    # The stacked inner cores are moved to the host once and then split along
    # the leading axis; np.array makes an independent copy of each core.
    # Unlike jnp.split, this also works if there are no inner cores at all:
    Ym_base = [np.array(G) for G in np.asarray(Ym)]

    return [np.array(Yl)] + Ym_base + [np.array(Yr)]
def copy(Y):
    """Return a copy of the given TT-tensor.

    Args:
        Y (list): TT-tensor (its first three items are copied).

    Returns:
        list: new list of three items, where each item is the result of the
        "copy" method called for the related item of the given TT-tensor.

    """
    return [Y[pos].copy() for pos in range(3)]
def get(Y, k):
    """Compute the element of the TT-tensor.

    Args:
        Y (list): d-dimensional TT-tensor.
        k (jnp.ndarray): the multi-index for the tensor of the length d.

    Returns:
        jnp.ndarray of size 1: the element of the TT-tensor.

    """
    def step(vec, item):
        # Multiply the row vector by the core slice for the given mode index:
        idx, core = item
        return jnp.einsum('q,qr->r', vec, core[:, idx, :]), None

    core_first, cores_inner, core_last = Y

    # The first core gives the initial row vector, the inner (stacked) cores
    # are processed by scan, and the last core is applied separately:
    vec = core_first[0, k[0], :]
    vec, _ = jax.lax.scan(step, vec, (k[1:-1], cores_inner))
    vec, _ = step(vec, (k[-1], core_last))

    return vec[0]
def get_log(Y, k):
    """Compute the logarithm of the element of the TT-tensor.

    The interface vector is normalized after each TT-core, and the logarithms
    of the norms are summed up. Since only norms are accumulated, the result
    is the logarithm of the absolute value of the element.

    Args:
        Y (list): d-dimensional TT-tensor.
        k (jnp.ndarray): the multi-index for the tensor of the length d.

    Returns:
        jnp.ndarray of size 1: the logarithm of the element of the TT-tensor.

    """
    def body(q, data):
        i, G = data
        # Only the slice for the requested mode index is contracted (there is
        # no need to multiply by the whole core and then select one row):
        v = q @ G[:, i, :]
        p = jnp.sqrt(jnp.sum(v**2))
        return v / p, p

    Yl, Ym, Yr = Y

    q, pl = body(jnp.ones(1), (k[0], Yl))
    q, pm = jax.lax.scan(body, q, (k[1:-1], Ym))
    q, pr = body(q, (k[-1], Yr))

    # The norm of the final (already normalized) vector is kept as a factor,
    # so invalid values (nan) of the vector propagate into the result:
    y = jnp.hstack((pl, pm, pr, jnp.linalg.norm(q)))

    return jnp.sum(jnp.log(y))
def get_many(Y, K):
    """Compute the elements of the TT-tensor on many multi-indices.

    Args:
        Y (list): d-dimensional TT-tensor.
        K (jnp.ndarray): the multi-indices for the tensor of the shape
            [samples, d].

    Returns:
        jnp.ndarray: the elements of the TT-tensor for multi-indices K (array
        of the length samples).

    """
    def step(rows, item):
        # One row vector per sample is multiplied by its own core slice:
        idx, core = item
        return jnp.einsum('kq,qkr->kr', rows, core[:, idx, :]), None

    core_first, cores_inner, core_last = Y

    rows = core_first[0, K[:, 0], :]
    # The inner indices are transposed, so scan iterates over the dimensions:
    rows, _ = jax.lax.scan(step, rows, (K[:, 1:-1].T, cores_inner))
    rows, _ = step(rows, (K[:, -1], core_last))

    return rows[:, 0]
def get_stab(Y, k):
    """Compute the element of the TT-tensor with stabilization factor.

    After each TT-core (except the first one), the interface vector is scaled
    by a power of two selected from its maximum absolute value. If the vector
    is exactly zero, then no scaling is applied (the related factor is 0), so
    a zero element is returned as zero and not as nan.

    Args:
        Y (list): d-dimensional TT-tensor.
        k (jnp.ndarray): the multi-index for the tensor of the length d.

    Returns:
        (jnp.ndarray of size 1, jnp.ndarray): the scaled value of the TT-tensor
        v and stabilization factor p for each TT-core (array of the length d;
        the factor for the first TT-core is always zero). The resulting value
        is v * 2^{sum(p)}.

    """
    def body(q, data):
        i, G = data
        q = jnp.einsum('q,qr->r', q, G[:, i, :])
        q_max = jnp.max(jnp.abs(q))
        # For the zero vector log2 gives -inf, and the subsequent division
        # 0 / 2^{-inf} gives nan, hence we use the unit scale in this case:
        q_ref = jnp.where(q_max > 0, q_max, 1.)
        p = jnp.floor(jnp.log2(q_ref))
        q = q / 2.**p
        return q, p

    Yl, Ym, Yr = Y

    # The first core is used as is (without scaling):
    q, pl = Yl[0, k[0], :], 0
    q, pm = jax.lax.scan(body, q, (k[1:-1], Ym))
    q, pr = body(q, (k[-1], Yr))

    return q[0], jnp.hstack((pl, pm, pr))
def grad(Y, k):
    """Compute gradients of the TT-tensor for given multi-index.

    For each TT-core, the outer product of the left interface vector (the
    product of the selected slices of all previous cores) and the right
    interface vector (the same for all next cores) is built.

    Args:
        Y (list): d-dimensional TT-tensor.
        k (jnp.ndarray): the multi-index for the tensor.

    Returns:
        list: the matrices which collect the gradients for all TT-cores
        (one matrix for the first core, a stack of matrices for the inner
        cores and one matrix for the last core).

    Todo:
        Move z construction into separate interface_* functions.

    """
    def push_right(vec, item):
        core, idx = item
        vec = vec @ core[:, idx, :]
        return vec, vec

    def push_left(vec, item):
        core, idx = item
        vec = core[:, idx, :] @ vec
        return vec, vec

    core_first, cores_inner, core_last = Y
    inner = (cores_inner, k[1:-1])
    unit = jnp.ones(1)

    # Left-to-right sweep: interface vectors on the left of each core:
    carry, left_head = push_right(unit, (core_first, k[0]))
    _, left_chain = jax.lax.scan(push_right, carry, inner)
    left_inner = jnp.vstack((left_head, left_chain[:-1]))
    left_last = left_chain[-1]

    # Right-to-left sweep: interface vectors on the right of each core:
    carry, right_tail = push_left(unit, (core_last, k[-1]))
    _, right_chain = jax.lax.scan(push_left, carry, inner, reverse=True)
    right_first = right_chain[0]
    right_inner = jnp.vstack((right_chain[1:], right_tail))

    # Outer products of the left and right interfaces (for the inner cores,
    # they are built at once for the whole stack by broadcasting):
    grad_first = jnp.outer(unit, right_first)
    grad_inner = left_inner[:, :, None] * right_inner[:, None, :]
    grad_last = jnp.outer(left_last, unit)

    return [grad_first, grad_inner, grad_last]
def interface_ltr(Y):
    """Generate the left to right interface vectors for the TT-tensor Y.

    Each TT-core (except the last one, which is not used) is summed over its
    mode index, and the running product is normalized after each step.

    Args:
        Y (list): d-dimensional TT-tensor.

    Returns:
        (jnp.ndarray, jnp.ndarray): inner interface vectors zm (stacked array
        with d-2 rows) and the right interface vector zr.

    """
    def step(vec, core):
        vec = vec @ jnp.sum(core, axis=1)
        vec = vec / jnp.linalg.norm(vec)
        return vec, vec

    core_first, cores_inner = Y[:-1]

    carry, head = step(jnp.ones(1), core_first)
    _, chain = jax.lax.scan(step, carry, cores_inner)

    # The last vector of the chain is the interface for the last core:
    return jnp.vstack((head, chain[:-1])), chain[-1]
def interface_rtl(Y):
    """Generate the right to left interface vectors for the TT-tensor Y.

    Each TT-core (except the first one, which is not used) is summed over its
    mode index, and the running product is normalized after each step.

    Args:
        Y (list): d-dimensional TT-tensor.

    Returns:
        (jnp.ndarray, jnp.ndarray): left interface vector zl and inner
        interface vectors zm (stacked array with d-2 rows).

    """
    def step(vec, core):
        vec = jnp.sum(core, axis=1) @ vec
        vec = vec / jnp.linalg.norm(vec)
        return vec, vec

    cores_inner, core_last = Y[1:]

    carry, tail = step(jnp.ones(1), core_last)
    # The inner cores are passed in the reverse order (from the last one):
    _, chain = jax.lax.scan(step, carry, cores_inner, reverse=True)

    # The first vector of the chain is the interface for the first core:
    return chain[0], jnp.vstack((chain[1:], tail))
def mean(Y):
    """Compute mean value of the TT-tensor.

    Each TT-core is convolved over its mode index with the uniform weights
    (1 / mode size), and the resulting matrices are multiplied.

    Args:
        Y (list): TT-tensor.

    Returns:
        jnp.ndarray of size 1: the mean value of the TT-tensor.

    """
    def step(acc, core):
        n = core.shape[1]
        weights = jnp.ones(n) / n
        return acc @ jnp.einsum('riq,i->rq', core, weights), None

    core_first, cores_inner, core_last = Y

    acc, _ = step(jnp.ones((1, 1)), core_first)
    acc, _ = jax.lax.scan(step, acc, cores_inner)
    acc, _ = step(acc, core_last)

    return acc[0, 0]
def mean_stab(Y):
    """Compute mean value of the TT-tensor with stabilization factor.

    After each TT-core, the accumulated matrix is scaled by a power of two
    selected from its maximum absolute value. If the matrix is exactly zero,
    then no scaling is applied (the related factor is 0), so a zero mean
    value is returned as zero and not as nan.

    Args:
        Y (list): TT-tensor with d dimensions.

    Returns:
        (jnp.ndarray of size 1, jnp.ndarray): the scaled mean value of the
        TT-tensor m and stabilization factor p for each TT-core (array of the
        length d). The resulting value is m * 2^{sum(p)}.

    """
    def scan(R, Y_cur):
        k = Y_cur.shape[1]
        Q = jnp.ones(k) / k
        R = R @ jnp.einsum('riq,i->rq', Y_cur, Q)
        r_max = jnp.max(jnp.abs(R))
        # For the zero matrix log2 gives -inf, and the subsequent division
        # 0 / 2^{-inf} gives nan, hence we use the unit scale in this case:
        r_ref = jnp.where(r_max > 0, r_max, 1.)
        p = jnp.floor(jnp.log2(r_ref))
        R = R / 2.**p
        return R, p

    Yl, Ym, Yr = Y

    R, pl = scan(jnp.ones((1, 1)), Yl)
    R, pm = jax.lax.scan(scan, R, Ym)
    R, pr = scan(R, Yr)

    return R[0, 0], jnp.hstack((pl, pm, pr))
def norm(Y, use_stab=False):
    """Compute Frobenius norm of the given TT-tensor.

    The norm is computed as a square root of the value returned by the
    function "mul_scalar" called for the pair (Y, Y).

    Args:
        Y (list): TT-tensor.
        use_stab (bool): this flag is not used in the current implementation.

    Returns:
        jnp.ndarray of size 1: Frobenius norm of the TT-tensor.

    Todo:
        Check negative values from "mul_scalar".

    """
    return jnp.sqrt(teneva.mul_scalar(Y, Y))
def norm_stab(Y):
    """Compute Frobenius norm of the given TT-tensor with stab. factor.

    The function "mul_scalar_stab" is called for the pair (Y, Y); then the
    square root of the returned value is taken, and the returned
    stabilization factor is divided by two.

    Args:
        Y (list): TT-tensor.

    Returns:
        (jnp.ndarray of size 1, list): Frobenius norm of the TT-tensor and
        stabilization factor p for each TT-core.

    Todo:
        Check negative values from "mul_scalar".

    """
    value_sq, p_sq = teneva.mul_scalar_stab(Y, Y)
    return jnp.sqrt(value_sq), p_sq / 2
def sum(Y):
    """Compute sum of all tensor elements.

    Each TT-core is summed over its mode index (by the convolution with the
    vector of ones), and the resulting matrices are multiplied.

    Args:
        Y (list): TT-tensor.

    Returns:
        jnp.ndarray of size 1: the sum of all tensor elements.

    """
    def step(acc, core):
        weights = jnp.ones(core.shape[1])
        return acc @ jnp.einsum('riq,i->rq', core, weights), None

    core_first, cores_inner, core_last = Y

    acc, _ = step(jnp.ones((1, 1)), core_first)
    acc, _ = jax.lax.scan(step, acc, cores_inner)
    acc, _ = step(acc, core_last)

    return acc[0, 0]
def sum_stab(Y):
    """Compute sum of all tensor elements with stabilization factor.

    After each TT-core, the accumulated matrix is scaled by a power of two
    selected from its maximum absolute value. If the matrix is exactly zero,
    then no scaling is applied (the related factor is 0), so a zero sum is
    returned as zero and not as nan.

    Args:
        Y (list): TT-tensor with d dimensions.

    Returns:
        (jnp.ndarray of size 1, jnp.ndarray): the scaled sum of all TT-tensor
        elements m and stabilization factor p for each TT-core (array of the
        length d). The resulting value is m * 2^{sum(p)}.

    """
    def scan(R, Y_cur):
        k = Y_cur.shape[1]
        Q = jnp.ones(k)
        R = R @ jnp.einsum('rmq,m->rq', Y_cur, Q)
        r_max = jnp.max(jnp.abs(R))
        # For the zero matrix log2 gives -inf, and the subsequent division
        # 0 / 2^{-inf} gives nan, hence we use the unit scale in this case:
        r_ref = jnp.where(r_max > 0, r_max, 1.)
        p = jnp.floor(jnp.log2(r_ref))
        R = R / 2.**p
        return R, p

    Yl, Ym, Yr = Y

    R, pl = scan(jnp.ones((1, 1)), Yl)
    R, pm = jax.lax.scan(scan, R, Ym)
    R, pr = scan(R, Yr)

    return R[0, 0], jnp.hstack((pl, pm, pr))
def _tt_tail_sizes(Y):
    # TMP (will be removed?)
    #
    # Y is a list of TT-cores here (one array per core). We collect the first
    # axis size of each core plus the last axis size of the last core, find
    # the core numbers where the neighboring sizes differ, and then select the
    # longest interval between two consecutive changes. The function returns
    # the pair (number of cores up to and including the first bound of this
    # interval, number of cores starting from its second bound), or (0, 0) if
    # the sizes never change.
    n_cores = len(Y)

    ranks = [G.shape[0] for G in Y]
    ranks.append(Y[-1].shape[-1])
    ranks = jnp.array(ranks)

    is_changed = ranks[1:] != ranks[:-1]
    idx_ch = jnp.arange(n_cores)[is_changed]

    if idx_ch.shape[0] == 0:
        return 0, 0

    # now len(idx_ch) >= 2
    gaps = idx_ch[1:] - idx_ch[:-1]
    i_longest = jnp.argmax(gaps)

    head = idx_ch[i_longest] + 1
    tail = n_cores - idx_ch[i_longest + 1]
    return head, tail