jasmine.svm.SVMClassifier

class jasmine.svm.SVMClassifier(C: float = 1.0, learning_rate: float = 0.01, n_epochs: int = 1000, class_weight: Dict | None = None, loss_function: Callable = hinge_loss)

A linear Support Vector Machine (SVM) classifier.

This model minimizes the hinge loss with gradient descent and supports class weighting and early stopping.
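For reference, the textbook soft-margin objective for a linear SVM can be written as below, with labels $y_i \in \{-1, 1\}$ and $s_{y_i}$ the class weight of sample $i$'s label (1 when no `class_weight` is given). Larger $C$ weakens the regularization term relative to the hinge term. The averaging over $n$ samples and the $\tfrac{1}{2}$ factor are assumptions; how jasmine scales each term internally is not documented here.

```latex
\min_{\mathbf{w},\, b} \;\; \frac{1}{2}\lVert \mathbf{w} \rVert^2
  \;+\; \frac{C}{n} \sum_{i=1}^{n} s_{y_i} \max\bigl(0,\; 1 - y_i(\mathbf{w}^\top \mathbf{x}_i + b)\bigr)
```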

__init__(C: float = 1.0, learning_rate: float = 0.01, n_epochs: int = 1000, class_weight: Dict | None = None, loss_function: Callable = hinge_loss)
Parameters:
  • C – Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive.

  • learning_rate – The step size for gradient descent.

  • n_epochs – The maximum number of passes (epochs) over the training data.

  • class_weight – Weights associated with classes in the form {class_label: weight}.

  • loss_function – The loss function to use. Defaults to hinge_loss.
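A `class_weight` mapping is typically expanded into one weight per training sample before the loss is computed. The sketch below shows that step with jax.numpy; the helper name `sample_weights` and the default weight of 1.0 for classes missing from the mapping are assumptions for illustration, not documented jasmine behavior.

```python
import jax.numpy as jnp


def sample_weights(y, class_weight=None):
    """Map each label in y to its class weight (hypothetical helper).

    Classes missing from class_weight, or every class when class_weight
    is None, get weight 1.0 (an assumption, not documented behavior).
    """
    weights = jnp.ones(y.shape, dtype=jnp.float32)
    if class_weight is None:
        return weights
    for label, w in class_weight.items():
        # Overwrite the weight wherever the sample's label matches this class.
        weights = jnp.where(y == label, w, weights)
    return weights


y = jnp.array([1, -1, -1, 1, -1])
print(sample_weights(y, {1: 3.0, -1: 1.0}).tolist())  # [3.0, 1.0, 1.0, 3.0, 1.0]
```

Up-weighting the minority class this way makes each of its hinge-loss terms count more in the objective.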

init_params(n_features: int, key: PRNGKey | None = None)

Initialize model parameters.

Parameters:
  • n_features (int) – Number of features in the input data.

  • key (jax.random.PRNGKey, optional) – Random key for parameter initialization.
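The following sketch shows one way `init_params` could create a weight vector and a scalar bias from a JAX random key. The parameter names `"w"` and `"b"`, the 0.01-scaled normal initialization, and the fixed fallback seed are assumptions. The real layout of the params dict is not documented here.

```python
import jax
import jax.numpy as jnp


def init_params(n_features, key=None):
    """Illustrative linear-SVM parameter initialization (hypothetical layout)."""
    if key is None:
        key = jax.random.PRNGKey(0)  # fall back to a fixed seed
    # Small random weights, one per input feature.
    w = 0.01 * jax.random.normal(key, (n_features,))
    # Scalar bias starts at zero.
    b = jnp.zeros(())
    return {"w": w, "b": b}


params = init_params(n_features=3, key=jax.random.PRNGKey(42))
print(params["w"].shape, params["b"].shape)  # (3,) ()
```

Passing an explicit key makes initialization reproducible, since JAX random functions are deterministic given the key.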

static forward(params: dict, X: Array) → Array

Forward pass for the SVM classifier.

Parameters:
  • params (dict) – Model parameters

  • X (jnp.ndarray) – Input features

Returns:

Predicted values

Return type:

jnp.ndarray
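For a linear SVM, the forward pass is usually the decision function f(X) = Xw + b, whose sign indicates the class. A minimal sketch follows, assuming the params dict holds a weight vector under the hypothetical key `"w"` and a scalar bias under `"b"`.

```python
import jax.numpy as jnp


def forward(params, X):
    """Linear decision function X @ w + b (assumes keys "w" and "b")."""
    return X @ params["w"] + params["b"]


params = {"w": jnp.array([1.0, -2.0]), "b": jnp.array(0.5)}
X = jnp.array([[1.0, 0.0], [0.0, 1.0]])
print(forward(params, X).tolist())  # [1.5, -1.5]
```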

loss_fn(params: dict, X: Array, y: Array) → Array

Compute the loss for the SVM classifier.
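A soft-margin linear SVM loss is commonly an L2 penalty on the weights plus the C-scaled mean hinge loss. In the sketch below, the local `hinge_loss` stands in for the library's default loss, whose exact signature is not shown here. The `"w"`/`"b"` keys and the optional per-sample `weights` argument, used for class weighting, are also assumptions.

```python
import jax
import jax.numpy as jnp


def hinge_loss(y, scores):
    """Per-sample hinge loss max(0, 1 - y * f(x)) for labels y in {-1, 1}."""
    return jnp.maximum(0.0, 1.0 - y * scores)


def loss_fn(params, X, y, C=1.0, weights=None):
    """Sketch: 0.5 * ||w||^2 + C * mean of (optionally weighted) hinge losses."""
    scores = X @ params["w"] + params["b"]
    losses = hinge_loss(y, scores)
    if weights is not None:
        losses = losses * weights  # e.g. per-sample class weights
    return 0.5 * jnp.sum(params["w"] ** 2) + C * jnp.mean(losses)


params = {"w": jnp.array([1.0, 0.0]), "b": jnp.array(0.0)}
X = jnp.array([[2.0, 0.0], [0.5, 1.0]])
y = jnp.array([1.0, 1.0])
print(float(loss_fn(params, X, y)))  # 0.75

# jax.grad differentiates with respect to the whole params dict.
grads = jax.grad(loss_fn)(params, X, y)
```

Only the second sample lies inside the margin (y * f(x) = 0.5 < 1), so it alone contributes to the hinge term and its gradient.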

train(X: Array, y: Array, validation_data: Tuple | None = None, early_stopping_patience: int | None = None, verbose: int = 1)

Train the SVM classifier.

Important: the labels y must be encoded as -1 and 1, not as 0 and 1.
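If the raw labels are encoded as {0, 1}, they can be mapped to {-1, 1} before training. A short jax.numpy sketch follows. The example label array is made up for illustration.

```python
import jax.numpy as jnp

# Example labels encoded as {0, 1}.
y01 = jnp.array([0, 1, 1, 0])

# Map 1 -> 1.0 and 0 -> -1.0.
y_pm = jnp.where(y01 == 1, 1.0, -1.0)
print(y_pm.tolist())  # [-1.0, 1.0, 1.0, -1.0]

# Equivalent arithmetic form: 2 * y - 1.
assert (2 * y01 - 1).tolist() == y_pm.tolist()
```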

Parameters:
  • X (jnp.ndarray) – Input features

  • y (jnp.ndarray) – Target labels, encoded as -1 or 1

  • validation_data (tuple, optional) – Tuple of (X_val, y_val) for validation

  • early_stopping_patience (int, optional) – Number of epochs with no improvement to wait before stopping

  • verbose (int) – Verbosity level

Returns:

Fitted model parameters

Return type:

dict
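The training procedure described above, gradient descent for up to `n_epochs` with optional early stopping on `validation_data`, can be sketched as a full-batch loop. This is an illustrative reimplementation, not jasmine's code. It assumes a params dict with keys `"w"` and `"b"`, zero initialization, early stopping that monitors the validation objective, and that the best parameters seen are returned. `class_weight` and `verbose` are omitted for brevity.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, X, y, C=1.0):
    """Soft-margin objective: 0.5 * ||w||^2 + C * mean hinge loss, labels in {-1, 1}."""
    scores = X @ params["w"] + params["b"]
    hinge = jnp.maximum(0.0, 1.0 - y * scores)
    return 0.5 * jnp.sum(params["w"] ** 2) + C * jnp.mean(hinge)


def train(X, y, validation_data=None, early_stopping_patience=None,
          C=1.0, learning_rate=0.01, n_epochs=1000):
    """Full-batch gradient descent with optional early stopping (sketch)."""
    params = {"w": jnp.zeros(X.shape[1]), "b": jnp.zeros(())}
    grad_fn = jax.jit(jax.grad(loss_fn))
    # Early stopping is active only when validation data and a patience are given.
    track = validation_data is not None and early_stopping_patience is not None
    best_loss, best_params, bad_epochs = float("inf"), params, 0

    for _ in range(n_epochs):
        grads = grad_fn(params, X, y, C)
        # Gradient-descent step applied to every array in the params dict.
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads)

        if track:
            X_val, y_val = validation_data
            val_loss = float(loss_fn(params, X_val, y_val, C))
            if val_loss < best_loss:
                best_loss, best_params, bad_epochs = val_loss, params, 0
            else:
                bad_epochs += 1
                if bad_epochs >= early_stopping_patience:
                    break  # patience exhausted: stop training

    # With early stopping active, return the best parameters seen.
    return best_params if track else params


X = jnp.array([[2.0, 2.0], [1.5, 2.5], [-2.0, -1.5], [-1.0, -2.5]])
y = jnp.array([1.0, 1.0, -1.0, -1.0])
params = train(X, y, validation_data=(X, y), early_stopping_patience=10,
               learning_rate=0.1, n_epochs=500)
print(jnp.sign(X @ params["w"] + params["b"]).tolist())
```

To keep the example short, the training set is reused as validation data. In practice `validation_data` should be a held-out split.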

inference(X: Array) → Array

Make predictions using the trained model.

Parameters:

X (jnp.ndarray) – Input features

Returns:

Predicted values

Return type:

jnp.ndarray
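A common way to turn linear-SVM decision scores into class labels is to take their sign. The sketch below assumes that inference thresholds the forward output at 0 and returns labels in {-1, 1}. The tie-breaking rule, where a score of exactly 0 maps to +1, and the params keys are assumptions.

```python
import jax.numpy as jnp


def inference(params, X):
    """Map decision scores to labels in {-1, 1} (sketch; ties at 0 go to +1)."""
    scores = X @ params["w"] + params["b"]
    return jnp.where(scores >= 0, 1, -1)


params = {"w": jnp.array([1.0, -1.0]), "b": jnp.array(0.0)}
X = jnp.array([[3.0, 1.0], [0.0, 2.0], [1.0, 1.0]])
print(inference(params, X).tolist())  # [1, -1, 1]
```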

evaluate(X: Array, y: Array, metrics_fn=accuracy_score) → float

Evaluate the model using the specified metrics function.

Parameters:
  • X (jnp.ndarray) – Input features

  • y (jnp.ndarray) – True labels

  • metrics_fn (callable) – Metrics function to compute the score

Returns:

Computed metrics score

Return type:

float
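`evaluate` applies `metrics_fn` to the true labels and the model's predictions. The sketch below assumes the call is `metrics_fn(y_true, y_pred)`. The `accuracy_score` shown stands in for the library default, and `error_rate` is a hypothetical custom metric with the same signature.

```python
import jax.numpy as jnp


def accuracy_score(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    return jnp.mean(y_true == y_pred)


def error_rate(y_true, y_pred):
    """Hypothetical custom metric with the same signature: 1 - accuracy."""
    return 1.0 - accuracy_score(y_true, y_pred)


y_true = jnp.array([1, -1, 1, 1])
y_pred = jnp.array([1, -1, -1, 1])
print(float(accuracy_score(y_true, y_pred)))  # 0.75
print(float(error_rate(y_true, y_pred)))      # 0.25
```

Under that assumption, a custom metric would be passed as `model.evaluate(X_test, y_test, metrics_fn=error_rate)`, where `model`, `X_test`, and `y_test` are placeholders.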