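"""Source code for jasmine.datasets._generators: synthetic regression, polynomial-regression, and classification dataset generators built on JAX."""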

import jax
import jax.numpy as jnp
import os
from typing import Optional, Tuple


def generate_regression(
    n_samples: int = 100,
    n_features: int = 20,
    n_informative: int = 10,
    noise: float = 0.0,
    bias: float = 0.0,
    shuffle: bool = True,
    coef: bool = False,
    random_state: Optional[int] = None,
):
    """
    Generate a random linear regression problem with JAX.

    The targets are a linear combination of ``n_informative`` of the input
    features plus a constant bias, with optional Gaussian noise. Every other
    feature has a ground-truth coefficient of exactly zero.

    Args:
        n_samples: The number of samples (rows of ``X``).
        n_features: The total number of features (columns of ``X``).
        n_informative: The number of features that receive a non-zero
            coefficient. Must not exceed ``n_features``.
        noise: Standard deviation of the Gaussian noise added to ``y``.
            No noise is added unless ``noise > 0``.
        bias: The intercept added to every target.
        shuffle: If True, the informative features are a random subset of the
            columns. If False, they are always the first ``n_informative``
            columns. In both cases the rows and columns of ``X`` themselves
            are not reordered.
        coef: If True, also return the ground-truth coefficients and the bias.
        random_state: Seed for ``jax.random.PRNGKey``. If None, a seed is
            drawn from ``os.urandom``, so results are not reproducible.

    Returns:
        ``(X, y)`` by default, where ``X`` has shape ``(n_samples, n_features)``
        and ``y`` has shape ``(n_samples,)``. If ``coef`` is True, returns
        ``(X, y, ground_truth, bias)``, where ``ground_truth`` has shape
        ``(n_features,)`` and ``bias`` is the value passed in.

    Raises:
        ValueError: If ``n_informative`` is greater than ``n_features``.
    """
    if n_informative > n_features:
        raise ValueError(
            f"n_informative ({n_informative}) cannot be greater than n_features ({n_features})"
        )

    # Without an explicit seed, draw one from the OS entropy source.
    if random_state is None:
        random_state = int.from_bytes(os.urandom(4), "big")
    key = jax.random.PRNGKey(random_state)
    # One independent subkey per random quantity so they never share draws.
    x_key, w_key, noise_key, shuffle_key = jax.random.split(key, 4)

    X = jax.random.normal(x_key, (n_samples, n_features))

    # Non-informative columns keep a coefficient of exactly zero.
    ground_truth = jnp.zeros(n_features)
    # Informative weights are drawn with standard deviation 100.
    informative_weights = 100 * jax.random.normal(w_key, (n_informative,))

    if shuffle:
        # A random permutation decides which columns are informative.
        indices = jax.random.permutation(shuffle_key, jnp.arange(n_features))
        informative_indices = indices[:n_informative]
        ground_truth = ground_truth.at[informative_indices].set(informative_weights)
    else:
        ground_truth = ground_truth.at[:n_informative].set(informative_weights)

    y = X @ ground_truth + bias

    if noise > 0:
        y += jax.random.normal(noise_key, (n_samples,)) * noise

    if coef:
        return X, y, ground_truth, bias
    else:
        return X, y
def generate_polynomial(
    n_samples: int = 100,
    degree: int = 2,
    noise: float = 0.0,
    bias: float = 0.0,
    coef: bool = False,
    random_state: Optional[int] = None,
):
    """
    Build a one-feature polynomial regression dataset.

    A single feature ``x`` is sampled uniformly from ``[-5, 5)`` and sorted in
    ascending order. The target is

        y = w_1 * x + w_2 * x**2 + ... + w_degree * x**degree + bias (+ noise)

    where each coefficient ``w_k`` is drawn as ``5 * N(0, 1)``. Apart from
    ``bias`` there is no constant term.

    Args:
        n_samples: How many points to draw.
        degree: Highest power of ``x`` in the polynomial.
        noise: Standard deviation of Gaussian noise added to ``y``; only
            applied when strictly positive.
        bias: Intercept added to every target.
        coef: When True, also return the polynomial coefficients and the bias.
        random_state: Seed for ``jax.random.PRNGKey``; when None, a seed is
            taken from ``os.urandom``.

    Returns:
        ``(X, y)`` where ``X`` has shape ``(n_samples, 1)`` and ``y`` has shape
        ``(n_samples,)``; or ``(X, y, coefficients, bias)`` when ``coef`` is
        True, with ``coefficients`` of shape ``(degree,)``.
    """
    seed = int.from_bytes(os.urandom(4), "big") if random_state is None else random_state
    x_key, w_key, noise_key = jax.random.split(jax.random.PRNGKey(seed), 3)

    # Sorting along the sample axis keeps X monotonic, which makes plots readable.
    raw_x = jax.random.uniform(x_key, (n_samples, 1), minval=-5, maxval=5)
    X = jnp.sort(raw_x, axis=0)

    weights = jax.random.normal(w_key, (degree,)) * 5

    # Broadcasting (n_samples, 1) against (degree,) gives column k = x**(k + 1).
    exponents = jnp.arange(1, degree + 1)
    design = X**exponents

    y = design @ weights + bias
    if noise > 0.0:
        y = y + jax.random.normal(noise_key, (n_samples,)) * noise

    if not coef:
        return X, y
    return X, y, weights, bias
def generate_classification(
    n_samples: int = 100,
    n_features: int = 20,
    n_informative: int = 5,
    n_redundant: int = 2,
    n_classes: int = 2,
    class_sep: float = 1.0,
    feature_noise: float = 1.0,
    redundant_noise: float = 0.0,
    shuffle: bool = True,
    random_state: Optional[int] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Generate a random n-class classification problem.

    Each class gets a centroid whose informative coordinates are drawn
    independently from ``{-class_sep, +class_sep}``, i.e. a vertex of a
    hypercube with side length ``2 * class_sep``. Samples are assigned to
    classes uniformly at random, and their informative features are the class
    centroid plus Gaussian noise. Redundant features are random linear
    combinations of the informative ones. All remaining features are pure
    standard-normal noise.

    Before optional shuffling, the column layout is
    ``[informative | redundant | noise]``.

    Notes:
        Centroid vertices are sampled with replacement, so two classes can
        share a centroid. This is guaranteed when ``n_classes > 2 ** n_informative``.
        If ``n_informative`` is 0, labels are independent of ``X``. In that
        case any redundant columns are left as all zeros.

    Args:
        n_samples: The number of samples.
        n_features: The total number of features.
        n_informative: The number of informative (class-dependent) features.
        n_redundant: The number of redundant features (linear combinations of
            the informative features).
        n_classes: The number of classes (labels ``0 .. n_classes - 1``).
        class_sep: Half the side length of the hypercube. Larger values spread
            the classes apart and make the problem easier.
        feature_noise: Standard deviation of the Gaussian noise added to the
            informative features. ``0`` places every sample exactly on its
            class centroid.
        redundant_noise: Standard deviation of the Gaussian noise added to the
            redundant features after the linear combination.
        shuffle: If True, randomly permute the feature columns. The order of
            the samples is never changed.
        random_state: Seed for ``jax.random.PRNGKey``. If None, a seed is
            drawn from ``os.urandom``.

    Returns:
        A tuple ``(X, y)``. ``X`` has shape ``(n_samples, n_features)``.
        ``y`` has shape ``(n_samples,)`` and holds integer labels in
        ``[0, n_classes)``.

    Raises:
        ValueError: If any size is out of range, if ``class_sep`` is not
            positive, if either noise level is negative, or if
            ``n_informative + n_redundant > n_features``.
    """
    # Validate input parameters.
    if n_samples <= 0:
        raise ValueError(f"n_samples must be a positive integer, got {n_samples}")
    if n_features <= 0:
        raise ValueError(f"n_features must be a positive integer, got {n_features}")
    if n_informative < 0:
        raise ValueError(f"n_informative must be non-negative, got {n_informative}")
    if n_redundant < 0:
        raise ValueError(f"n_redundant must be non-negative, got {n_redundant}")
    if n_classes < 1:
        raise ValueError(f"n_classes must be at least 1, got {n_classes}")
    if class_sep <= 0:
        raise ValueError(f"class_sep must be positive, got {class_sep}")
    if feature_noise < 0:
        raise ValueError(f"feature_noise must be non-negative, got {feature_noise}")
    if redundant_noise < 0:
        raise ValueError(f"redundant_noise must be non-negative, got {redundant_noise}")
    if n_informative + n_redundant > n_features:
        raise ValueError(
            f"n_informative ({n_informative}) + n_redundant ({n_redundant}) "
            f"cannot be greater than n_features ({n_features})"
        )

    if random_state is None:
        random_state = int.from_bytes(os.urandom(4), "big")
    main_key = jax.random.PRNGKey(random_state)
    # One subkey per random quantity so independent parts never share draws.
    keys = jax.random.split(main_key, 6)
    centroid_key, class_key, noise_key, redundant_key, redundant_noise_key, shuffle_key = keys

    # Class centroids: only the informative columns are class-separating; the
    # rest stay at zero.
    centroids = jnp.zeros((n_classes, n_features))
    if n_informative > 0:
        # Each informative coordinate is independently +/- class_sep (a hypercube vertex).
        informative_centroids = jax.random.choice(
            centroid_key, jnp.array([-class_sep, class_sep]), shape=(n_classes, n_informative)
        )
        centroids = centroids.at[:, :n_informative].set(informative_centroids)

    # Assign samples to classes uniformly at random.
    y = jax.random.randint(class_key, (n_samples,), 0, n_classes)

    X = jnp.zeros((n_samples, n_features))

    # Informative features: each sample's class centroid plus optional Gaussian noise.
    if n_informative > 0:
        informative_base = centroids[y, :n_informative]
        if feature_noise > 0:
            noise = jax.random.normal(noise_key, (n_samples, n_informative)) * feature_noise
            informative_features = informative_base + noise
        else:
            informative_features = informative_base
        X = X.at[:, :n_informative].set(informative_features)

    # Redundant features: random linear combinations of the informative columns.
    # With no informative features these columns are left as zeros.
    if n_redundant > 0 and n_informative > 0:
        w_redundant = jax.random.normal(redundant_key, (n_redundant, n_informative))
        redundant_features = X[:, :n_informative] @ w_redundant.T
        if redundant_noise > 0:
            redundant_noise_vals = (
                jax.random.normal(redundant_noise_key, (n_samples, n_redundant))
                * redundant_noise
            )
            redundant_features = redundant_features + redundant_noise_vals
        X = X.at[:, n_informative : n_informative + n_redundant].set(redundant_features)

    # Remaining columns are pure standard-normal noise.
    n_noise = n_features - n_informative - n_redundant
    if n_noise > 0:
        # fold_in derives a key distinct from the one used for informative noise.
        noise_key_final = jax.random.fold_in(noise_key, 1)
        noise_features = jax.random.normal(noise_key_final, (n_samples, n_noise))
        X = X.at[:, -n_noise:].set(noise_features)

    # Permute columns (axis=1) so informative features are not always first.
    if shuffle:
        feature_indices = jax.random.permutation(shuffle_key, n_features)
        X = X[:, feature_indices]

    return X, y