Source code for teneva_jax.vis

"""Package teneva, module vis: visualization methods for tensors.

This module contains the functions for visualization of TT-tensors.

"""
import jax.numpy as jnp


[docs]def show(Y): """Check and display mode size and TT-rank of the given TT-tensor. Args: Y (list): TT-tensor. """ if not isinstance(Y, list) or len(Y) != 3: raise ValueError('Invalid TT-tensor') Yl, Ym, Yr = Y if not isinstance(Yl, jnp.ndarray) or len(Yl.shape) != 3: raise ValueError('Invalid left core of TT-tensor') if not isinstance(Ym, jnp.ndarray) or len(Ym.shape) != 4: raise ValueError('Invalid middle cores of TT-tensor') if not isinstance(Yr, jnp.ndarray) or len(Yr.shape) != 3: raise ValueError('Invalid right core of TT-tensor') if Ym.shape[1] != Ym.shape[3]: raise ValueError('Invalid shape of middle cores for TT-tensor') d = Ym.shape[0] + 2 n = Ym.shape[2] r = Ym.shape[3] if r > n: raise ValueError('TT-rank should be no greater than mode size') if Yl.shape[0] != 1 or Yl.shape[1] != n or Yl.shape[2] != r: raise ValueError('Invalid shape of left core for TT-tensor') if Yr.shape[0] != r or Yr.shape[1] != n or Yr.shape[2] != 1: raise ValueError('Invalid shape of right core for TT-tensor') text = f'TT-tensor-jax | d = {d:-5d} | n = {n:-5d} | r = {r:-5d} |' print(text)