Performance Metrics
The jasmine.metrics module provides performance evaluation metrics.
Regression Metrics
mean_squared_error | Mean Squared Error loss.
mean_absolute_error | Compute the Mean Absolute Error (MAE) between true and predicted values.
root_mean_squared_error | Compute the Root Mean Squared Error (RMSE) between true and predicted values.
r2_score | Compute the R² score (coefficient of determination) between true and predicted values.
Classification Metrics
accuracy_score | Compute the accuracy score between true and predicted values.
binary_cross_entropy | Binary cross-entropy loss for labels in
categorical_cross_entropy | Compute the Categorical Cross-Entropy loss between true and predicted values.
Regression Functions
mean_squared_error
mean_absolute_error
root_mean_squared_error
r2_score
Classification Functions
accuracy_score
binary_cross_entropy
categorical_cross_entropy
- jasmine.metrics.categorical_cross_entropy(y_true, y_pred, from_logits: bool = False)
Compute the Categorical Cross-Entropy loss between true and predicted values.
- Parameters:
y_true (jnp.ndarray) – True categorical target values (one-hot encoded).
y_pred (jnp.ndarray) – Predicted probabilities or logits for each class.
from_logits (bool) – If True, y_pred is expected to be a raw logit output.
- Returns:
Computed Categorical Cross-Entropy loss.
- Return type:
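To make the role of from_logits concrete, the following is a minimal reference sketch in plain JAX. It is not JASMINE's actual implementation: the function name categorical_cross_entropy_sketch, the clipping constant eps, and the mean reduction over samples are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def categorical_cross_entropy_sketch(y_true, y_pred, from_logits=False, eps=1e-7):
    """Illustrative reference only; not JASMINE's actual implementation."""
    if from_logits:
        # log_softmax is numerically safer than log(softmax(x)).
        log_probs = jax.nn.log_softmax(y_pred, axis=-1)
    else:
        # Clip to avoid log(0); eps is an assumed value.
        log_probs = jnp.log(jnp.clip(y_pred, eps, 1.0))
    # Sum over classes, then average over samples (mean reduction is assumed).
    return -jnp.mean(jnp.sum(y_true * log_probs, axis=-1))


# Two samples, three classes, one-hot targets.
y_true = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
probs = jax.nn.softmax(logits, axis=-1)

loss_from_probs = categorical_cross_entropy_sketch(y_true, probs)
loss_from_logits = categorical_cross_entropy_sketch(y_true, logits, from_logits=True)
print(float(loss_from_probs), float(loss_from_logits))
```

Both calls should agree up to floating-point error, since the logits path applies the softmax internally before taking the logarithm.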
Examples
Regression Metrics
from jasmine.metrics import (
    mean_squared_error,
    mean_absolute_error,
    r2_score,
)
import jax.numpy as jnp
# Sample predictions and targets
y_true = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = jnp.array([1.1, 1.9, 3.2, 3.8, 5.1])
# Calculate metrics
mse = mean_squared_error(y_true, y_pred)
mae = mean_absolute_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)
print(f"MSE: {mse:.4f}")
print(f"MAE: {mae:.4f}")
print(f"R²: {r2:.4f}")
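For reference, the same quantities can be computed by hand from their standard textbook definitions. The sketch below uses only jax.numpy and the sample data from the example above. It also covers RMSE. It illustrates the definitions and does not reproduce JASMINE's internals.

```python
import jax.numpy as jnp

y_true = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_pred = jnp.array([1.1, 1.9, 3.2, 3.8, 5.1])

residuals = y_true - y_pred

# Mean of squared residuals, mean of absolute residuals, and sqrt of the MSE.
mse = jnp.mean(residuals ** 2)
mae = jnp.mean(jnp.abs(residuals))
rmse = jnp.sqrt(mse)

# R² = 1 - SS_res / SS_tot, where SS_tot is measured around the mean of y_true.
ss_res = jnp.sum(residuals ** 2)
ss_tot = jnp.sum((y_true - jnp.mean(y_true)) ** 2)
r2 = 1.0 - ss_res / ss_tot

print(f"MSE:  {float(mse):.4f}")   # 0.0220
print(f"MAE:  {float(mae):.4f}")   # 0.1400
print(f"RMSE: {float(rmse):.4f}")  # 0.1483
print(f"R²:   {float(r2):.4f}")    # 0.9890
```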
Classification Metrics
from jasmine.metrics import accuracy_score, binary_cross_entropy
import jax.numpy as jnp
# Binary classification
y_true = jnp.array([0, 1, 1, 0, 1])
y_pred_classes = jnp.array([0, 1, 1, 0, 0])
y_pred_probs = jnp.array([0.1, 0.9, 0.8, 0.2, 0.4])
# Calculate metrics
accuracy = accuracy_score(y_true, y_pred_classes)
bce = binary_cross_entropy(y_true, y_pred_probs)
print(f"Accuracy: {accuracy:.4f}")
print(f"Binary Cross-Entropy: {bce:.4f}")
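As a cross-check, accuracy and binary cross-entropy can be written out directly from their standard definitions. This sketch reuses the sample data above and relies only on jax.numpy. The clipping constant eps is an assumption added to avoid log(0), and JASMINE may handle that case differently.

```python
import jax.numpy as jnp

y_true = jnp.array([0, 1, 1, 0, 1])
y_pred_classes = jnp.array([0, 1, 1, 0, 0])
y_pred_probs = jnp.array([0.1, 0.9, 0.8, 0.2, 0.4])

# Accuracy: fraction of predictions that match the labels (4 of 5 here).
accuracy = jnp.mean(y_true == y_pred_classes)

# Binary cross-entropy: -mean(y*log(p) + (1-y)*log(1-p)).
eps = 1e-7  # assumed clipping value to keep the logarithms finite
p = jnp.clip(y_pred_probs, eps, 1.0 - eps)
bce = -jnp.mean(y_true * jnp.log(p) + (1 - y_true) * jnp.log(1.0 - p))

print(f"Accuracy: {float(accuracy):.4f}")            # 0.8000
print(f"Binary Cross-Entropy: {float(bce):.4f}")     # 0.3147
```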
Using with Models
from jasmine.linear_model import LinearRegression
from jasmine.metrics import mean_squared_error
# Train model (X_train, y_train, X_test, and y_test are assumed to be existing JAX arrays)
model = LinearRegression()
model.train(X_train, y_train)
# Evaluate with custom metric
mse = model.evaluate(X_test, y_test, metrics_fn=mean_squared_error)
r2 = model.evaluate(X_test, y_test) # Default metric
print(f"MSE: {mse:.4f}")
print(f"R²: {r2:.4f}")
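The example above passes a built-in metric to metrics_fn. Assuming metrics_fn accepts any callable with the signature (y_true, y_pred) that returns a scalar, which this page does not confirm, a user-defined metric could be sketched as follows. The name max_absolute_error is hypothetical and not part of jasmine.metrics.

```python
import jax.numpy as jnp


def max_absolute_error(y_true, y_pred):
    """Hypothetical custom metric: the largest absolute residual."""
    return jnp.max(jnp.abs(y_true - y_pred))


# Standalone check on small arrays; no model is needed to test the metric.
y_true = jnp.array([1.0, 2.0, 3.0])
y_pred = jnp.array([1.5, 2.0, 2.0])
worst = max_absolute_error(y_true, y_pred)
print(f"Max absolute error: {float(worst):.4f}")  # 1.0000
```

Checking a custom metric on small hand-computed arrays first makes it easier to trust the value it reports during model evaluation.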
Metric Properties
- All metrics are JIT-compiled for fast computation.
- Functions accept JAX arrays as input.
- Binary cross-entropy supports both probabilities and logits.
- Metrics are designed to work seamlessly with JASMINE models.
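The first and third properties can be illustrated with a short sketch: a hand-written binary cross-entropy that takes logits, wrapped in jax.jit. The function bce_from_logits_sketch is a hypothetical stand-in written for this example. It shows the general technique and makes no claim about how JASMINE implements its metrics.

```python
import jax
import jax.numpy as jnp


@jax.jit
def bce_from_logits_sketch(y_true, logits):
    # Numerically stable form: max(x, 0) - x*y + log(1 + exp(-|x|)).
    per_sample = (
        jnp.maximum(logits, 0.0)
        - logits * y_true
        + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )
    return jnp.mean(per_sample)


y_true = jnp.array([0.0, 1.0, 1.0, 0.0, 1.0])
probs = jnp.array([0.1, 0.9, 0.8, 0.2, 0.4])

# Convert probabilities to logits: log(p / (1 - p)).
logits = jnp.log(probs) - jnp.log1p(-probs)

loss_logits = bce_from_logits_sketch(y_true, logits)
loss_probs = -jnp.mean(y_true * jnp.log(probs) + (1.0 - y_true) * jnp.log1p(-probs))
print(float(loss_logits), float(loss_probs))
```

The two losses should match up to floating-point error. The first call to a jitted function traces and compiles it, and later calls with the same input shapes and dtypes reuse the compiled version.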