Performance Metrics

The jasmine.metrics module provides regression and classification metrics for evaluating model performance.

Regression Metrics

mean_squared_error(y_true, y_pred)

Compute the Mean Squared Error (MSE) between true and predicted values.

mean_absolute_error(y_true, y_pred)

Compute the Mean Absolute Error (MAE) between true and predicted values.

root_mean_squared_error(y_true, y_pred)

Compute the Root Mean Squared Error (RMSE) between true and predicted values.

r2_score(y_true, y_pred)

Compute the R² score (coefficient of determination) between true and predicted values.

Classification Metrics

accuracy_score(y_true, y_pred)

Compute the accuracy score between true and predicted values.

binary_cross_entropy(y_true, y_pred[, ...])

Compute the Binary Cross-Entropy loss for labels in {0, 1}.

categorical_cross_entropy(y_true, y_pred[, ...])

Compute the Categorical Cross-Entropy loss between true and predicted values.

Regression Functions

mean_squared_error

jasmine.metrics.mean_squared_error(y_true: Array, y_pred: Array) → Array

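Compute the Mean Squared Error (MSE) between true and predicted values.

Parameters:
  • y_true (Array) – True target values.

  • y_pred (Array) – Predicted target values.

Returns:

Computed MSE value.

Return type:

Array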
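
For reference, the standard MSE formula averages the squared residuals, MSE = (1/n) Σ (y_true,i − y_pred,i)². The sketch below writes that formula in jax.numpy; it is an illustration, not the jasmine source, and the function name mse_sketch is hypothetical. Averaging over every element (rather than per output column) is an assumption.

```python
import jax.numpy as jnp


def mse_sketch(y_true, y_pred):
    """Hypothetical reference MSE: mean of squared residuals over all elements."""
    y_true = jnp.asarray(y_true)
    y_pred = jnp.asarray(y_pred)
    return jnp.mean((y_true - y_pred) ** 2)


# Residuals are 0, 1, -2 -> squared 0, 1, 4 -> mean 5/3
print(mse_sketch(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 5.0])))
```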

mean_absolute_error

jasmine.metrics.mean_absolute_error(y_true, y_pred)

Compute the Mean Absolute Error (MAE) between true and predicted values.

Parameters:
  • y_true (jnp.ndarray) – True target values.

  • y_pred (jnp.ndarray) – Predicted target values.

Returns:

Computed MAE value.

Return type:

float
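
A minimal sketch of the standard MAE formula, (1/n) Σ |y_true,i − y_pred,i|, in jax.numpy. The name mae_sketch is hypothetical and this is not the jasmine implementation. Because residuals enter linearly rather than squared, MAE is less sensitive to a few large errors than MSE.

```python
import jax.numpy as jnp


def mae_sketch(y_true, y_pred):
    """Hypothetical reference MAE: mean of absolute residuals."""
    return jnp.mean(jnp.abs(jnp.asarray(y_true) - jnp.asarray(y_pred)))


# Absolute residuals 0, 1, 2 -> mean 1.0
print(mae_sketch(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0, 1.0, 5.0])))
```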

root_mean_squared_error

jasmine.metrics.root_mean_squared_error(y_true, y_pred)

Compute the Root Mean Squared Error (RMSE) between true and predicted values.

Parameters:
  • y_true (jnp.ndarray) – True target values.

  • y_pred (jnp.ndarray) – Predicted target values.

Returns:

Computed RMSE value.

Return type:

float
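
RMSE is the square root of MSE, which puts the error back in the units of the target. The following jax.numpy sketch (hypothetical name rmse_sketch, not the jasmine source) shows the standard formula:

```python
import jax.numpy as jnp


def rmse_sketch(y_true, y_pred):
    """Hypothetical reference RMSE: square root of the mean squared residual."""
    residuals = jnp.asarray(y_true) - jnp.asarray(y_pred)
    return jnp.sqrt(jnp.mean(residuals ** 2))


# Residuals 3 and -4 -> squared 9, 16 -> mean 12.5 -> sqrt(12.5) ≈ 3.5355
print(rmse_sketch(jnp.array([0.0, 0.0]), jnp.array([-3.0, 4.0])))
```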

r2_score

jasmine.metrics.r2_score(y_true, y_pred)

Compute the R² score (coefficient of determination) between true and predicted values.

Parameters:
  • y_true (jnp.ndarray) – True target values.

  • y_pred (jnp.ndarray) – Predicted target values.

Returns:

Computed R² score.

Return type:

float
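
R² compares the residual sum of squares to the spread of y_true around its mean: R² = 1 − SS_res / SS_tot. A perfect fit gives 1, always predicting the mean of y_true gives 0, and worse-than-mean predictions go negative. The sketch below (hypothetical name r2_sketch, not the jasmine source) implements this standard definition; how jasmine handles a constant y_true is not documented here.

```python
import jax.numpy as jnp


def r2_sketch(y_true, y_pred):
    """Hypothetical reference R^2 = 1 - SS_res / SS_tot.

    When y_true is constant, SS_tot is 0 and this sketch returns nan or -inf;
    it does not guard against that case.
    """
    y_true = jnp.asarray(y_true)
    y_pred = jnp.asarray(y_pred)
    ss_res = jnp.sum((y_true - y_pred) ** 2)            # residual sum of squares
    ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)  # total sum of squares
    return 1.0 - ss_res / ss_tot


y = jnp.array([1.0, 2.0, 3.0, 4.0])
print(r2_sketch(y, y))                      # perfect fit -> 1.0
print(r2_sketch(y, jnp.full_like(y, 2.5)))  # always predicting the mean -> 0.0
```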

Classification Functions

accuracy_score

jasmine.metrics.accuracy_score(y_true, y_pred)

Compute the accuracy score between true and predicted values.

Parameters:
  • y_true (jnp.ndarray) – True target values.

  • y_pred (jnp.ndarray) – Predicted target values.

Returns:

Computed accuracy score.

Return type:

float
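
Accuracy is the fraction of positions where the predicted label equals the true label. The sketch below (hypothetical name accuracy_sketch) assumes both inputs already hold discrete class labels; if you start from probabilities, convert them first, for example with a 0.5 threshold. Whether accuracy_score thresholds internally is not stated in this reference.

```python
import jax.numpy as jnp


def accuracy_sketch(y_true, y_pred):
    """Hypothetical reference accuracy: fraction of exact label matches."""
    matches = jnp.asarray(y_true) == jnp.asarray(y_pred)
    return jnp.mean(matches.astype(jnp.float32))


# 3 of 4 labels match -> 0.75
print(accuracy_sketch(jnp.array([0, 1, 2, 1]), jnp.array([0, 1, 1, 1])))
```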

binary_cross_entropy

jasmine.metrics.binary_cross_entropy(y_true: Array, y_pred: Array, from_logits: bool = False, sample_weight: Array | None = None) → Array

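Compute the Binary Cross-Entropy loss for labels in {0, 1}.

Parameters:
  • y_true (Array) – True binary labels (0 or 1).

  • y_pred (Array) – Predicted probabilities, or raw logits if from_logits is True.

  • from_logits (bool) – If True, y_pred is treated as raw logits rather than probabilities.

  • sample_weight (Array | None) – Optional per-sample weights; None weights all samples equally.

Returns:

Computed Binary Cross-Entropy loss.

Return type:

Array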
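
For a label y ∈ {0, 1} and predicted probability p, the per-sample loss is −[y log p + (1 − y) log(1 − p)]. The sketch below illustrates that standard formula and the two options in the signature; it is not the jasmine source. The name bce_sketch, the eps clipping, and the weighted-average reduction sum(w·ℓ)/sum(w) for sample_weight are all assumptions. With from_logits=True it uses jax.nn.log_sigmoid, which stays finite for large-magnitude logits.

```python
import jax
import jax.numpy as jnp


def bce_sketch(y_true, y_pred, from_logits=False, sample_weight=None, eps=1e-7):
    """Hypothetical reference binary cross-entropy (mean over samples).

    Not the jasmine source: the eps clipping and the weighted-average
    reduction are assumptions made for this sketch.
    """
    y = jnp.asarray(y_true, dtype=jnp.float32)
    y_pred = jnp.asarray(y_pred, dtype=jnp.float32)
    if from_logits:
        # log(sigmoid(z)) and log(1 - sigmoid(z)) computed stably from logits
        log_p = jax.nn.log_sigmoid(y_pred)
        log_1mp = jax.nn.log_sigmoid(-y_pred)
    else:
        # Clip probabilities away from 0 and 1 so the logs stay finite
        p = jnp.clip(y_pred, eps, 1.0 - eps)
        log_p = jnp.log(p)
        log_1mp = jnp.log1p(-p)
    losses = -(y * log_p + (1.0 - y) * log_1mp)
    if sample_weight is None:
        return jnp.mean(losses)
    w = jnp.asarray(sample_weight, dtype=jnp.float32)
    return jnp.sum(w * losses) / jnp.sum(w)


y = jnp.array([0, 1])
probs = jnp.array([0.2, 0.8])
logits = jnp.log(probs / (1.0 - probs))  # inverse sigmoid of probs
print(bce_sketch(y, probs))                     # -ln(0.8) ≈ 0.2231
print(bce_sketch(y, logits, from_logits=True))  # same value computed from logits
```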

categorical_cross_entropy

jasmine.metrics.categorical_cross_entropy(y_true, y_pred, from_logits: bool = False)

Compute the Categorical Cross-Entropy loss between true and predicted values.

Parameters:
  • y_true (jnp.ndarray) – True categorical target values (one-hot encoded).

  • y_pred (jnp.ndarray) – Predicted probabilities or logits for each class.

  • from_logits (bool) – If True, y_pred is treated as raw logits (unnormalized scores) rather than probabilities.

Returns:

Computed Categorical Cross-Entropy loss.

Return type:

float
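
For one-hot targets, categorical cross-entropy reduces to −log of the probability assigned to the true class, averaged over samples. The sketch below (hypothetical name cce_sketch, with an assumed eps clip and mean reduction; not the jasmine source) shows the standard formula. With from_logits=True it normalizes with jax.nn.log_softmax instead of taking the log of probabilities directly.

```python
import jax
import jax.numpy as jnp


def cce_sketch(y_true, y_pred, from_logits=False, eps=1e-7):
    """Hypothetical reference categorical cross-entropy, averaged over samples.

    y_true is one-hot with shape (n_samples, n_classes); y_pred holds either
    probabilities (rows summing to 1) or raw logits of the same shape.
    """
    y_true = jnp.asarray(y_true, dtype=jnp.float32)
    y_pred = jnp.asarray(y_pred, dtype=jnp.float32)
    if from_logits:
        # log_softmax turns logits into log-probabilities stably
        log_probs = jax.nn.log_softmax(y_pred, axis=-1)
    else:
        log_probs = jnp.log(jnp.clip(y_pred, eps, 1.0))
    # Per-sample loss: -sum_k y_k * log p_k, then average over samples
    return jnp.mean(-jnp.sum(y_true * log_probs, axis=-1))


y_true = jnp.array([[1, 0, 0], [0, 1, 0]])
probs = jnp.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
print(cce_sketch(y_true, probs))  # mean of -ln(0.7) and -ln(0.8) ≈ 0.2899
```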

Examples

Regression Metrics

from jasmine.metrics import (
    mean_squared_error,
    mean_absolute_error,
    r2_score
)
import jax.numpy as jnp

# Sample predictions and targets
y_true = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = jnp.array([1.1, 1.9, 3.2, 3.8, 5.1])

# Calculate metrics
mse = mean_squared_error(y_true, y_pred)
mae = mean_absolute_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)

print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r2:.4f}")
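
Assuming mean_squared_error, mean_absolute_error and r2_score follow the standard definitions, the printed values can be checked by hand. The residuals y_pred − y_true are (0.1, −0.1, 0.2, −0.2, 0.1) and the mean of y_true is 3:

```latex
\begin{aligned}
\text{MSE} &= \tfrac{1}{5}(0.01 + 0.01 + 0.04 + 0.04 + 0.01) = \tfrac{0.11}{5} = 0.022 \\
\text{MAE} &= \tfrac{1}{5}(0.1 + 0.1 + 0.2 + 0.2 + 0.1) = \tfrac{0.7}{5} = 0.14 \\
SS_{\text{res}} &= 0.11, \qquad SS_{\text{tot}} = (1-3)^2 + (2-3)^2 + (3-3)^2 + (4-3)^2 + (5-3)^2 = 10 \\
R^2 &= 1 - \tfrac{SS_{\text{res}}}{SS_{\text{tot}}} = 1 - \tfrac{0.11}{10} = 0.989
\end{aligned}
```

so the script should print values close to MSE 0.0220, MAE 0.1400 and R² 0.9890, up to float32 rounding.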

Classification Metrics

from jasmine.metrics import accuracy_score, binary_cross_entropy
import jax.numpy as jnp

# Binary classification
y_true = jnp.array([0, 1, 1, 0, 1])
y_pred_classes = jnp.array([0, 1, 1, 0, 0])
y_pred_probs = jnp.array([0.1, 0.9, 0.8, 0.2, 0.4])

# Calculate metrics
accuracy = accuracy_score(y_true, y_pred_classes)
bce = binary_cross_entropy(y_true, y_pred_probs)

print(f"Accuracy: {accuracy:.4f}")
print(f"Binary Cross-Entropy: {bce:.4f}")
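
Under the same assumption of standard definitions (mean reduction, natural log), the classification example can be checked by hand. Four of the five predicted classes match, and each sample contributes −log of the probability assigned to its true label:

```latex
\begin{aligned}
\text{Accuracy} &= \tfrac{4}{5} = 0.8 \\
\text{BCE} &= -\tfrac{1}{5}\left[\ln 0.9 + \ln 0.9 + \ln 0.8 + \ln 0.8 + \ln 0.4\right] \\
&\approx \tfrac{1}{5}(0.1054 + 0.1054 + 0.2231 + 0.2231 + 0.9163) = \tfrac{1.5733}{5} \approx 0.3147
\end{aligned}
```

Only the last sample, a positive given probability 0.4, contributes a large loss.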

Using with Models

from jasmine.linear_model import LinearRegression
from jasmine.metrics import mean_squared_error

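# Train model (X_train, y_train, X_test, y_test are assumed to be
# JAX arrays prepared beforehand; they are not defined in this snippet)
model = LinearRegression()
model.train(X_train, y_train)

# Evaluate with a custom metric function
mse = model.evaluate(X_test, y_test, metrics_fn=mean_squared_error)
r2 = model.evaluate(X_test, y_test)  # No metrics_fn: default metric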

print(f"MSE: {mse:.4f}")
print(f"R²: {r2:.4f}")
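
Any function with the same (y_true, y_pred) → scalar shape as the built-in metrics is a natural candidate for metrics_fn. The sketch below defines one; max_abs_error is a hypothetical name, and the assumption that evaluate calls metrics_fn(y_true, y_pred) in that order is not confirmed by this reference.

```python
import jax.numpy as jnp


def max_abs_error(y_true, y_pred):
    """Hypothetical custom metric: largest absolute residual.

    Follows the (y_true, y_pred) -> scalar convention shared by the
    jasmine.metrics functions documented above.
    """
    return jnp.max(jnp.abs(jnp.asarray(y_true) - jnp.asarray(y_pred)))


# Absolute residuals 0.5, 0, 2 -> max 2.0
print(max_abs_error(jnp.array([1.0, 2.0, 3.0]), jnp.array([1.5, 2.0, 1.0])))
```

With a trained model it would be passed the same way as the built-in metric: model.evaluate(X_test, y_test, metrics_fn=max_abs_error).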

Metric Properties

  • All metrics are JIT-compiled for fast computation

  • Functions accept JAX arrays as input

  • Both cross-entropy losses accept probabilities or logits, selected with from_logits

  • Metrics can be passed to a JASMINE model's evaluate method through metrics_fn
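
JIT compilation has one practical consequence for metric options such as from_logits: jax.jit cannot branch in Python on a traced value, so a boolean flag has to be marked static. The sketch below shows the pattern on a hypothetical regression_error metric; it illustrates jax.jit with static_argnames and is not taken from the jasmine source.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical metric: `squared` is a Python bool used in an `if`, so it is
# declared static; jit then compiles one version per value of the flag.
@partial(jax.jit, static_argnames="squared")
def regression_error(y_true, y_pred, squared=True):
    residuals = y_true - y_pred
    if squared:
        return jnp.mean(residuals ** 2)
    return jnp.mean(jnp.abs(residuals))


y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.0, 2.0, 5.0])
print(regression_error(y_true, y_pred))                 # (0 + 0 + 4) / 3 ≈ 1.3333
print(regression_error(y_true, y_pred, squared=False))  # (0 + 0 + 2) / 3 ≈ 0.6667
```

Each distinct value of a static argument triggers its own compilation, which is cheap for a two-valued flag.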