Source code for jasmine.metrics._classification

import jax
import jax.numpy as jnp


def binary_cross_entropy(y_true, y_pred, from_logits: bool = False):
    """
    Compute the Binary Cross-Entropy loss between true and predicted values.

    Args:
        y_true (jnp.ndarray): Ground truth binary labels (0 or 1).
        y_pred (jnp.ndarray): Predicted probabilities (between 0 and 1).
        from_logits (bool): If True, y_pred is expected to be logits.
            If False, y_pred should be probabilities.

    Returns:
        float: Computed Binary Cross-Entropy loss.
    """
    if from_logits:
        # Numerically stable calculation using jax.nn.log_sigmoid.
        # This is equivalent to applying sigmoid and then binary cross-entropy,
        # but prevents numerical issues with very small or large logits.
        log_p = jax.nn.log_sigmoid(y_pred)
        log_not_p = jax.nn.log_sigmoid(-y_pred)  # log(1 - sigmoid(x)) = log(sigmoid(-x))
        return -jnp.mean(y_true * log_p + (1.0 - y_true) * log_not_p)

    # Standard binary cross-entropy with clipping for stability
    epsilon = 1e-15  # To avoid log(0)
    y_pred = jnp.clip(y_pred, epsilon, 1 - epsilon)
    return -jnp.mean(y_true * jnp.log(y_pred) + (1 - y_true) * jnp.log(1 - y_pred))


[docs] def categorical_cross_entropy(y_true, y_pred, from_logits: bool = False): """ Compute the Categorical Cross-Entropy loss between true and predicted values. Args: y_true (jnp.ndarray): True categorical target values (one-hot encoded). y_pred (jnp.ndarray): Predicted probabilities or logits for each class. from_logits (bool): If True, y_pred is expected to be a raw logit output. Returns: float: Computed Categorical Cross-Entropy loss. """ if from_logits: # Use numerically stable log_softmaxs log_probs = jax.nn.log_softmax(y_pred, axis=-1) else: epsilon = 1e-15 # To avoid log(0) y_pred = jnp.clip(y_pred, epsilon, 1 - epsilon) log_probs = jnp.log(y_pred) return -jnp.mean(jnp.sum(y_true * log_probs, axis=-1))
[docs] @jax.jit def accuracy_score(y_true, y_pred): """ Compute the accuracy score between true and predicted values. Args: y_true (jnp.ndarray): True target values. y_pred (jnp.ndarray): Predicted target values. Returns: float: Computed accuracy score. """ return jnp.mean(y_true == y_pred)
@jax.jit def euclidean_distance(x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray: """ Computes the squared Euclidean distance between a single point and a batch of points. Note: We use squared Euclidean distance because it's computationally cheaper (avoids the square root) and preserves the ranking of distances, which is all that KNN needs to find the nearest neighbors. Args: x1 (jnp.ndarray): A single data point of shape (n_features,). x2 (jnp.ndarray): A batch of data points of shape (n_samples, n_features). Returns: jnp.ndarray: Computed Euclidean distance(n_samples,). """ return jnp.sum((x1 - x2) ** 2, axis=1) @jax.jit def manhattan_distance(x1: jnp.ndarray, x2: jnp.ndarray) -> jnp.ndarray: """ Computes the Manhattan distance between a single point and a batch of points. Args: x1 (jnp.ndarray): A single data point of shape (n_features,). x2 (jnp.ndarray): A batch of data points of shape (n_samples, n_features). Returns: jnp.ndarray: Computed Manhattan distance(n_samples,). """ return jnp.sum(jnp.abs(x1 - x2), axis=1)