# Source code for teneva_jax.data
"""Package teneva_jax, module data: functions for working with datasets.

This module contains the "accuracy_on_data" function, which computes the
relative error of a TT-tensor on a dataset.
"""
import jax
import jax.numpy as jnp
import teneva_jax as teneva
def accuracy_on_data(Y, I_data=None, y_data=None):
    """Compute the relative error of TT-tensor on the dataset.

    The error is ``||y - y_data|| / ||y_data||`` (Euclidean norms), where
    ``y`` holds the values of the TT-tensor ``Y`` at the multi-indices
    ``I_data``, obtained with ``teneva.get_many``.

    Args:
        Y: TT-tensor to evaluate (any object accepted by
            ``teneva.get_many``).
        I_data (jnp.ndarray): multi-indices for items of dataset in the form of
            array of the shape [samples, d].
        y_data (jnp.ndarray): values for items related to I_data of dataset in
            the form of array of the shape [samples]. Array-likes (e.g. lists)
            are converted with ``jnp.asarray``.

    Returns:
        jnp.ndarray of size 1: the relative error; or the float -1. if I_data
        or y_data is not provided (is None).

    Note:
        If I_data or y_data is not provided, the function will return -1.
        The division by ``||y_data||`` is not guarded, so a y_data with zero
        norm is expected to produce inf or nan rather than an exception.

    """
    # Documented contract: a missing dataset yields the sentinel -1 instead
    # of failing inside get_many.
    if I_data is None or y_data is None:
        return -1.

    y = teneva.get_many(Y, I_data)
    # jax ops reject plain Python lists, so normalize y_data once up front.
    y_data = jnp.asarray(y_data)
    return jnp.linalg.norm(y - y_data) / jnp.linalg.norm(y_data)