Module sample: random sampling for/from the TT-tensor
Package teneva_jax, module sample: random sampling for/from the TT-tensor.
This module contains functions for sampling from a TT-tensor and for generating random multi-indices and points for learning.
- teneva_jax.sample.sample(Y, zm, key)
Sample a multi-index according to the given probability TT-tensor.
- Parameters:
Y (list) – TT-tensor that represents the discrete probability distribution.
zm (list) – list of middle interface vectors for the tensor Y. To generate it, run “zl, zm = interface_rtl(Y)” and pass the returned zm.
key (jax.random.PRNGKey) – jax random key.
- Returns:
generated multi-index for the tensor.
- Return type:
jnp.ndarray
Examples:
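    import jax
    import jax.numpy as jnp
    import teneva_jax as teneva   # Assumed alias: "teneva" is the jax version here

    # Assumed setup: "rng" must exist before the first split.
    # Any seed works; the printed values depend on it.
    rng = jax.random.PRNGKey(42)

    rng, key = jax.random.split(rng)
    Y = teneva.rand(d=8, n=5, r=4, key=key)
    zl, zm = teneva.interface_rtl(Y)

    rng, key = jax.random.split(rng)
    i = teneva.sample(Y, zm, key)
    print(i)

    # >>> ----------------------------------------
    # >>> Output:
    # [0 4 1 4 0 2 4 1]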
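The idea behind sampling from a probability TT-tensor can be sketched without the library. The snippet below is a minimal illustration, assuming TT-cores with nonnegative entries; the helpers make_cores, right_interfaces and sample_one are hypothetical names, and this is not necessarily how teneva_jax implements sample or interface_rtl. It first sums out the modes from right to left, then draws the indices one mode at a time from the resulting conditional distributions.

```python
import jax
import jax.numpy as jnp


def make_cores(key, d, n, r):
    """Hypothetical helper: random nonnegative TT-cores of shape (r_left, n, r_right)."""
    ranks = [1] + [r] * (d - 1) + [1]
    keys = jax.random.split(key, d)
    return [jax.random.uniform(keys[j], (ranks[j], n, ranks[j + 1]))
            for j in range(d)]


def right_interfaces(cores):
    """z[k] is the result of summing out every mode to the right of mode k."""
    z = [jnp.ones(1)]
    for G in reversed(cores[1:]):
        # Sum the core over its mode index, then contract with the vector on the right.
        z.insert(0, G.sum(axis=1) @ z[0])
    return z


def sample_one(cores, z, key):
    """Draw one multi-index, mode by mode, from left to right."""
    q = jnp.ones(1)  # Left vector for the indices fixed so far
    index = []
    keys = jax.random.split(key, len(cores))
    for G, zk, k in zip(cores, z, keys):
        # Unnormalized conditional distribution of the current index.
        p = jnp.einsum("a,aib,b->i", q, G, zk)
        i = jax.random.choice(k, G.shape[1], p=p / p.sum())
        index.append(i)
        q = q @ G[:, i, :]
        q = q / q.sum()  # Rescaling keeps q bounded and does not change p / p.sum()
    return jnp.stack(index)


key_cores, key_sample = jax.random.split(jax.random.PRNGKey(0))
cores = make_cores(key_cores, d=3, n=4, r=2)
z = right_interfaces(cores)
i = sample_one(cores, z, key_sample)
print(i)  # One multi-index of length d = 3
```

In this sketch the right-to-left pass is done once and reused for every sample, which mirrors why the documented function takes the precomputed zm as an argument.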
Now let us check this function on a large random TT-tensor:
    interface_rtl = jax.jit(teneva.interface_rtl)
    sample = jax.jit(jax.vmap(teneva.sample, (None, None, 0)))

    rng, key = jax.random.split(rng)
    Y = teneva.rand(d=1000, n=100, r=10, key=key)

    zl, zm = interface_rtl(Y)

    m = 10 # Number of samples
    rng, key = jax.random.split(rng)
    I = sample(Y, zm, jax.random.split(key, m))

    for i in I: # i is a sample of the length d = 1000
        print(len(i), jnp.mean(i))

    # >>> ----------------------------------------
    # >>> Output:
    # 1000 48.005
    # 1000 48.943
    # 1000 50.079
    # 1000 50.75
    # 1000 48.632
    # 1000 49.833
    # 1000 50.394
    # 1000 49.366
    # 1000 49.688
    # 1000 49.441
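The batching pattern used above, jax.vmap with in_axes (None, None, 0), can be shown on a toy function. This is a minimal sketch; draw and draw_many are hypothetical names and the function only stands in for a sampler: the first two arguments are shared by all calls, while the array of keys is mapped over its leading axis, so each sample gets its own key.

```python
import jax
import jax.numpy as jnp


def draw(probs, offset, key):
    """Toy stand-in for a sampler: draw one index from probs and shift it."""
    return offset + jax.random.choice(key, probs.shape[0], p=probs)


# None: the argument is passed unchanged to every call; 0: map over axis 0.
draw_many = jax.jit(jax.vmap(draw, in_axes=(None, None, 0)))

rng = jax.random.PRNGKey(0)
rng, key = jax.random.split(rng)
keys = jax.random.split(key, 5)       # One key per sample
probs = jnp.array([0.1, 0.2, 0.7])

out = draw_many(probs, 10, keys)
print(out.shape)  # (5,)
```

Splitting one key into m keys, rather than reusing a single key, is what makes the m draws differ from each other.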
Let us compare this function with the NumPy implementation:
    import teneva as teneva_base              # Assumed alias for the NumPy version
    from time import perf_counter as tpc      # Assumed definition of "tpc"

    d = 25       # Dimension of the tensor
    n = 10       # Mode size of the tensor
    r = 5        # Rank of the tensor
    m = 100000   # Number of samples

    Y_base = teneva_base.rand([n]*d, r)

    t = tpc()
    I_base = teneva_base.sample(Y_base, m)
    t = tpc() - t

    print(f'Time : {t:-8.2f}')
    print(f'Mean : {jnp.mean(I_base):-8.2f}')
    print(f'Var  : {jnp.var(I_base):-8.2f}')

    # >>> ----------------------------------------
    # >>> Output:
    # Time :    54.39
    # Mean :     4.65
    # Var  :     7.59
    Y = teneva.convert(Y_base) # Convert it to the jax version

    # Note: the timer below starts before jax.jit is applied, so the
    # reported time includes the JIT compilation of both functions.
    t = tpc()
    interface_rtl = jax.jit(teneva.interface_rtl)
    sample = jax.jit(jax.vmap(teneva.sample, (None, None, 0)))
    zl, zm = interface_rtl(Y)
    rng, key = jax.random.split(rng)
    I = sample(Y, zm, jax.random.split(key, m))
    t = tpc() - t

    print(f'Time : {t:-8.2f}')
    print(f'Mean : {jnp.mean(I):-8.2f}')
    print(f'Var  : {jnp.var(I):-8.2f}')

    # >>> ----------------------------------------
    # >>> Output:
    # Time :     1.64
    # Mean :     4.65
    # Var  :     7.66
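When timing jitted code like this, it can help to separate the first call, which compiles, from later calls, and to wait for the result before stopping the timer, because JAX dispatches work asynchronously. The snippet below is a small sketch of that measurement pattern on a toy function (work is a hypothetical name); it does not reproduce the benchmark above.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def work(x):
    """Toy computation that stands in for a jitted sampler."""
    return jnp.sum(jnp.sin(x) ** 2)


x = jnp.arange(1000.0)

# First call: traces, compiles and runs.
t = time.perf_counter()
first = work(x).block_until_ready()
t_first = time.perf_counter() - t

# Second call: reuses the compiled code.
t = time.perf_counter()
second = work(x).block_until_ready()
t_second = time.perf_counter() - t

print(f'First call  : {t_first:-8.4f}')
print(f'Second call : {t_second:-8.4f}')
```

The call to block_until_ready() makes the timer cover the computation itself and not only the dispatch of the work.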
- teneva_jax.sample.sample_lhs(d, n, m, key)
Generate LHS (Latin hypercube sampling) multi-indices for the tensor of the given shape.
- Parameters:
d (int) – number of tensor dimensions.
n (int) – mode size of the tensor.
m (int) – number of samples.
key (jax.random.PRNGKey) – jax random key.
- Returns:
generated multi-indices for the tensor in the form of an array of shape [m, d].
- Return type:
jnp.ndarray
Examples:
    d = 3 # Dimension of the tensor/grid
    n = 5 # Shape of the tensor/grid
    m = 8 # Number of samples
    rng, key = jax.random.split(rng)
    I = teneva.sample_lhs(d, n, m, key)
    print(I)

    # >>> ----------------------------------------
    # >>> Output:
    # [[4 1 1]
    #  [3 0 0]
    #  [2 1 0]
    #  [0 3 4]
    #  [1 2 2]
    #  [3 2 3]
    #  [4 4 3]
    #  [0 3 1]]
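One way to build LHS-style multi-indices is sketched below. It is an illustration under stated assumptions, not necessarily the library's implementation, and lhs_indices is a hypothetical name: in every column each index value appears m // n times, the remaining m % n entries are distinct values chosen at random, and the column is then shuffled. The output above is consistent with this scheme (for m = 8 and n = 5 every value appears once or twice per column).

```python
import jax
import jax.numpy as jnp


def lhs_indices(key, d, n, m):
    """Hypothetical sketch: balanced random indices for each of the d modes."""
    def column(k):
        k_extra, k_shuffle = jax.random.split(k)
        # Every value 0..n-1 is repeated m // n times.
        base = jnp.repeat(jnp.arange(n), m // n)
        # Fill the remaining slots with distinct values.
        extra = jax.random.choice(
            k_extra, n, (m - base.shape[0],), replace=False)
        return jax.random.permutation(k_shuffle, jnp.concatenate([base, extra]))

    columns = [column(k) for k in jax.random.split(key, d)]
    return jnp.stack(columns, axis=1)  # Shape [m, d]


I = lhs_indices(jax.random.PRNGKey(0), d=3, n=5, m=8)
print(I.shape)  # (8, 3)
```

Compared with fully independent draws, this keeps the values of each mode evenly covered even when m is small.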
- teneva_jax.sample.sample_rand(d, n, m, key)
Generate random multi-indices for the tensor of the given shape.
- Parameters:
d (int) – number of tensor dimensions.
n (int) – mode size of the tensor.
m (int) – number of samples.
key (jax.random.PRNGKey) – jax random key.
- Returns:
generated multi-indices for the tensor in the form of an array of shape [m, d].
- Return type:
jnp.ndarray
Examples:
    d = 3 # Dimension of the tensor/grid
    n = 5 # Shape of the tensor/grid
    m = 8 # Number of samples
    rng, key = jax.random.split(rng)
    I = teneva.sample_rand(d, n, m, key)
    print(I)

    # >>> ----------------------------------------
    # >>> Output:
    # [[3 1 2]
    #  [3 1 1]
    #  [3 1 1]
    #  [1 2 2]
    #  [4 3 3]
    #  [4 4 1]
    #  [3 0 1]
    #  [2 4 4]]
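For comparison, independent uniform multi-indices of the same shape can be drawn directly with jax.random.randint. This is only a sketch of the idea, assuming every mode has the same size n; whether sample_rand is implemented this way is not stated here.

```python
import jax

d = 3  # Dimension of the tensor/grid
n = 5  # Mode size of the tensor/grid
m = 8  # Number of samples

key = jax.random.PRNGKey(0)
# Each entry is drawn uniformly from 0, ..., n - 1 (maxval is exclusive).
I = jax.random.randint(key, (m, d), 0, n)
print(I.shape)  # (8, 3)
```

Unlike the LHS variant, nothing here prevents repeated rows or uneven coverage of the index values within a mode.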