# Source code for teneva_jax.tensors

"""Package teneva, module tensors: various useful TT-tensors.

This module contains a collection of functions for the explicit construction
of various useful TT-tensors (only random tensors for now).

"""
import jax


def rand(d, n, r, key, a=-1., b=1.):
    """Build a random TT-tensor with uniformly distributed core entries.

    Args:
        d (int): number of tensor dimensions.
        n (int): mode size of the tensor.
        r (int): TT-rank of the tensor.
        key (jax.random.PRNGKey): jax random key.
        a (float): minimum value for random items of the TT-cores.
        b (float): maximum value for random items of the TT-cores.

    Returns:
        list: TT-tensor given as three arrays: the first core of shape
        (1, n, r), the stacked inner cores of shape (d-2, r, n, r) and the
        last core of shape (r, n, 1).

    """
    # Shapes are listed in the order left / middle / right; each array gets
    # its own subkey so that the three draws are independent.
    shapes = [(1, n, r), (d - 2, r, n, r), (r, n, 1)]
    subkeys = jax.random.split(key, 3)
    return [
        jax.random.uniform(subkey, shape, minval=a, maxval=b)
        for subkey, shape in zip(subkeys, shapes)
    ]
def rand_norm(d, n, r, key, m=0., s=1.):
    """Build a random TT-tensor with normally distributed core entries.

    Args:
        d (int): number of tensor dimensions.
        n (int): mode size of the tensor.
        r (int): TT-rank of the tensor.
        key (jax.random.PRNGKey): jax random key.
        m (float): mean ("centre") of the distribution.
        s (float): standard deviation of the distribution (>0).

    Returns:
        list: TT-tensor given as three arrays: the first core of shape
        (1, n, r), the stacked inner cores of shape (d-2, r, n, r) and the
        last core of shape (r, n, 1).

    """
    def sample(subkey, shape):
        # Shift and scale a standard normal draw to mean m and deviation s.
        return m + s * jax.random.normal(subkey, shape)

    # One independent subkey per array: first core, inner cores, last core.
    key_first, key_inner, key_last = jax.random.split(key, 3)

    core_first = sample(key_first, (1, n, r))
    cores_inner = sample(key_inner, (d - 2, r, n, r))
    core_last = sample(key_last, (r, n, 1))

    return [core_first, cores_inner, core_last]