Source code for jasmine.svm._svm

import jax
import jax.numpy as jnp
import os
import time
from typing import Callable, Optional, Dict, Tuple

from jasmine.losses import hinge_loss
from jasmine.metrics import accuracy_score


[docs] class SVMClassifier: """ A linear Support Vector Machine (SVM) classifier. This model uses the Hinge Loss and gradient descent for optimization and includes support for class weighting and early stopping. """
[docs] def __init__( self, C: float = 1.0, learning_rate: float = 0.01, n_epochs: int = 1000, class_weight: Optional[Dict] = None, loss_function: Callable = hinge_loss, ): """ Args: C: Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive. learning_rate: The step size for gradient descent. max_iter: The maximum number of passes over the training data. class_weight: Weights associated with classes in the form {class_label: weight}. loss_function: The loss function to use. Defaults to hinge_loss. """ if C <= 0: raise ValueError("Regularization parameter C must be positive.") self.C = C self.learning_rate = learning_rate self.n_epochs = n_epochs self.class_weight = class_weight self.loss_function = loss_function self.params = None
[docs] def init_params(self, n_features: int, key: Optional[jax.random.PRNGKey] = None): """ Initialize model parameters. Args: n_features (int): Number of features in the input data. key (jax.random.PRNGKey, optional): Random key for parameter initialization. """ if key is None: random_state = int.from_bytes(os.urandom(4), "big") key = jax.random.PRNGKey(random_state) w_key, _ = jax.random.split(key) params = {"w": jax.random.normal(w_key, (n_features,)), "b": jnp.array(0.0)} return params
[docs] @staticmethod def forward(params: dict, X: jnp.ndarray) -> jnp.ndarray: """ Forward pass for the SVM classifier. Args: params (dict): Model parameters X (jnp.ndarray): Input features Returns: jnp.ndarray: Predicted values """ return X @ params["w"] + params["b"]
[docs] def loss_fn(self, params: dict, X: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """ Compute the loss for the SVM classifier. """ scores = self.forward(params, X) sample_weights = None if self.class_weight: # Create sample weights based on the class_weight dictionary weight_neg = self.class_weight.get(-1, 1.0) weight_pos = self.class_weight.get(1, 1.0) sample_weights = jnp.where(y == 1, weight_pos, weight_neg) # Calculate the data loss using the specified loss function, scaled by C data_loss = self.C * self.loss_function(y, scores, sample_weights) # Add L2 penalty (weight decay) reg_loss = 0.5 * jnp.sum(params["w"] ** 2) return data_loss + reg_loss
[docs] def train( self, X: jnp.ndarray, y: jnp.ndarray, validation_data: Optional[Tuple] = None, early_stopping_patience: Optional[int] = None, verbose: int = 1, ): """ Train the SVM classifier. Important: The labels `y` must be transformed to {-1, 1}. Args: X (jnp.ndarray): Input features y (jnp.ndarray): Target labels validation_data (tuple, optional): Tuple of (X_val, y_val) for validation early_stopping_patience (int, optional): Number of epochs with no improvement to wait before stopping verbose (int): Verbosity level Returns: dict: Fitted model parameters """ if not jnp.all((y == 1) | (y == -1)): raise ValueError("Labels must be in the set {-1, 1}.") current_params = self.init_params(X.shape[1]) history: Dict[str, list[float]] = {"loss": [], "val_loss": []} best_val_loss = float("inf") epochs_no_improve = 0 best_params = None @jax.jit def update_step(params, X, y): grads = jax.grad(self.loss_fn)(params, X, y) return jax.tree_util.tree_map(lambda p, g: p - self.learning_rate * g, params, grads) start_time = time.time() for epoch in range(self.n_epochs): current_params = update_step(current_params, X, y) train_loss = self.loss_fn(current_params, X, y) history["loss"].append(train_loss) log_msg = f"Epoch {epoch + 1}/{self.n_epochs} - Loss: {train_loss:.4f}" if validation_data is not None: X_val, y_val = validation_data val_loss = self.loss_fn(current_params, X_val, y_val) history["val_loss"].append(val_loss) log_msg += f" - Val Loss: {val_loss:.4f}" if early_stopping_patience is not None: if val_loss < best_val_loss: best_val_loss = val_loss epochs_no_improve = 0 best_params = current_params else: epochs_no_improve += 1 if epochs_no_improve >= early_stopping_patience: if verbose > 0: print(f"\nEarly stopping triggered after {epoch+1} epochs.") self.params = best_params return history if verbose > 0: print(log_msg, end="\r") if verbose > 0: total_time = time.time() - start_time print(f"\nTraining completed in {total_time:.2f} seconds.") self.params = best_params if best_params is not None else current_params return history
[docs] def inference(self, X: jnp.ndarray) -> jnp.ndarray: """ Make predictions using the trained model. Args: X (jnp.ndarray): Input features Returns: jnp.ndarray: Predicted values """ if self.params is None: raise ValueError( "Model has not been trained yet. Call `train` before calling `inference`." ) scores = self.forward(self.params, X) return jnp.sign(scores).astype(int)
[docs] def evaluate(self, X: jnp.ndarray, y: jnp.ndarray, metrics_fn=accuracy_score) -> float: """ Evaluate the model using the specified metrics function. Args: X (jnp.ndarray): Input features y (jnp.ndarray): True labels metrics_fn (callable): Metrics function to compute the score Returns: float: Computed metrics score """ if self.params is None: raise ValueError( "Model has not been trained yet. Call `train` before calling `evaluate`." ) class_predictions = self.inference(X) return metrics_fn(y, class_predictions)