Source code for jasmine.model_selection._split

import jax
import jax.numpy as jnp


def train_test_split(X, y, test_size=0.2, shuffle=True, random_state=None):
    """
    Split arrays into random train and test subsets.

    Args:
        X (jnp.ndarray): Input features; samples are indexed along axis 0.
        y (jnp.ndarray): Target labels; must have the same length as ``X``
            along axis 0.
        test_size (float): Proportion of the dataset to include in the test
            split, strictly between 0.0 and 1.0. The number of test samples
            is ``int(n_samples * test_size)`` (truncated toward zero), but
            never fewer than 1.
        shuffle (bool): Whether to shuffle the data before splitting. When
            False, the leading samples form the training set and the trailing
            samples form the test set, in their original order.
        random_state (int, optional): Seed for the shuffling permutation.
            When None, a fixed seed of 0 is used, so the split is still
            deterministic across calls. Ignored when ``shuffle`` is False.

    Returns:
        Tuple: X_train, X_test, y_train, y_test

    Raises:
        ValueError: If ``test_size`` is not strictly between 0.0 and 1.0,
            if ``X`` and ``y`` have different numbers of samples, or if the
            split would leave no samples for the training set.
    """
    if not (0.0 < test_size < 1.0):
        raise ValueError("test_size must be between 0.0 and 1.0")

    n_samples = X.shape[0]
    # JAX clamps out-of-bounds gather indices instead of raising, so a length
    # mismatch would otherwise silently pair features with the wrong labels.
    if y.shape[0] != n_samples:
        raise ValueError(
            "X and y must have the same number of samples, "
            f"got {n_samples} and {y.shape[0]}"
        )

    n_test = max(1, int(n_samples * test_size))  # Ensure at least 1 test sample
    n_train = n_samples - n_test
    # With 0 or 1 samples the forced single test sample leaves nothing to
    # train on; fail loudly rather than return an empty training split.
    if n_train < 1:
        raise ValueError(
            f"test_size={test_size} with n_samples={n_samples} leaves no "
            "samples for the training set"
        )

    indices = jnp.arange(n_samples)
    if shuffle:
        # None falls back to a fixed seed so results stay reproducible.
        key = jax.random.PRNGKey(0 if random_state is None else random_state)
        indices = jax.random.permutation(key, indices)

    train_indices = indices[:n_train]
    test_indices = indices[n_train:]

    X_train = X[train_indices]
    X_test = X[test_indices]
    y_train = y[train_indices]
    y_test = y[test_indices]

    return X_train, X_test, y_train, y_test