Module data: functions for working with datasets

Package teneva, module data: functions for working with datasets.

This module contains functions for working with datasets, including the “accuracy_on_data” function.




teneva_jax.data.accuracy_on_data(Y, I_data, y_data)

Compute the relative error of the TT-tensor on the dataset.

Parameters:
  • Y – the TT-tensor.

  • I_data (jnp.ndarray) – multi-indices of the dataset items, as an array of shape [samples, d].

  • y_data (jnp.ndarray) – values of the dataset items corresponding to I_data, as an array of shape [samples].

Returns:

the relative error.

Return type:

jnp.ndarray of size 1

Note

If I_data or y_data is not provided, the function will return -1.
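
To make the returned quantity concrete, here is a minimal sketch of a relative error on a dataset. It assumes the metric is the L2 norm of the difference between predicted and reference values, divided by the L2 norm of the reference values; the source does not state the exact formula, so treat this as an illustration only. The helper name relative_error is hypothetical and is not part of teneva_jax. It is written in plain Python so it runs without JAX installed, and it takes the predicted values directly instead of evaluating a TT-tensor at I_data.

```python
import math


def relative_error(y_pred, y_data):
    """Relative L2 error of predictions against reference values.

    Hypothetical helper, not the library's code. Assumed metric:
    ||y_pred - y_data|| / ||y_data||.
    """
    # Mirror the documented convention: missing data yields -1.
    if y_pred is None or y_data is None:
        return -1.0
    diff_norm = math.sqrt(sum((p - t) ** 2 for p, t in zip(y_pred, y_data)))
    data_norm = math.sqrt(sum(t ** 2 for t in y_data))
    return diff_norm / data_norm


y_data = [3.0, 4.0]  # reference values, shape [samples]
y_pred = [3.0, 4.5]  # values a TT-tensor might give at the same multi-indices
eps = relative_error(y_pred, y_data)
print(f'Accuracy     : {eps:-8.2e}')
```

Under this assumption a value of 0 means the predictions reproduce the dataset exactly, and the -1 sentinel cannot be confused with a valid error because a ratio of norms is never negative.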

Examples:

Let us generate a random TT-tensor:

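import jax
import teneva as teneva_base  # assumed alias: base teneva package
import teneva_jax as teneva   # assumed alias: JAX version of teneva
rng = jax.random.PRNGKey(42)  # assumed seed; the output may differ

d = 20  # Dimension of the tensor
n = 10  # Mode size of the tensor
r = 2   # TT-rank of the tensor
rng, key = jax.random.split(rng)
Y = teneva.rand(d, n, r, key)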

Then we generate some random multi-indices, compute the related tensor values, and add some noise:

m = 100 # Size of the dataset
I_data = teneva_base.sample_lhs([n]*d, m)
y_data = teneva.get_many(Y, I_data)

rng, key = jax.random.split(rng)
y_data = y_data + 1.E-5*jax.random.normal(key, (m, ))

And then let us compute the accuracy:

eps = teneva.accuracy_on_data(Y, I_data, y_data)

print(f'Accuracy     : {eps:-8.2e}')

# >>> ----------------------------------------
# >>> Output:

# Accuracy     : 2.34e-03
#