jasmine.svm.SVMClassifier
- class jasmine.svm.SVMClassifier(C: float = 1.0, learning_rate: float = 0.01, n_epochs: int = 1000, class_weight: Dict | None = None, loss_function: Callable = hinge_loss)
A linear Support Vector Machine (SVM) classifier.
This model is trained by minimizing the hinge loss with gradient descent, and supports class weighting and early stopping.
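The exact objective that jasmine minimizes is not spelled out here. As a rough illustration, assuming the common primal formulation in which the weight penalty becomes relatively weaker as C grows, a class-weighted hinge objective could look like this. The function name `svm_objective` and the parameter keys `w` and `b` are hypothetical.

```python
import jax.numpy as jnp


def svm_objective(params, X, y, sample_weight, C=1.0):
    """Hypothetical primal SVM objective: 0.5*||w||^2 + C * weighted mean hinge loss.

    Assumes y is encoded as -1/1 and params holds a weight vector "w" and a bias "b".
    """
    scores = X @ params["w"] + params["b"]       # raw decision values
    hinge = jnp.maximum(0.0, 1.0 - y * scores)   # per-sample hinge loss
    data_term = jnp.mean(sample_weight * hinge)  # class-weighted average loss
    reg_term = 0.5 * jnp.sum(params["w"] ** 2)   # L2 penalty on the weights only
    return reg_term + C * data_term
```

Because C multiplies only the data term, increasing C makes the penalty on `w` comparatively weaker, which matches the description of C below.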
- __init__(C: float = 1.0, learning_rate: float = 0.01, n_epochs: int = 1000, class_weight: Dict | None = None, loss_function: Callable = hinge_loss)
- Parameters:
C (float) – Regularization parameter. The strength of the regularization is inversely proportional to C. Must be strictly positive.
learning_rate (float) – The step size for gradient descent.
n_epochs (int) – The maximum number of passes (epochs) over the training data.
class_weight (dict, optional) – Weights associated with classes in the form {class_label: weight}.
loss_function (callable) – The loss function to use. Defaults to hinge_loss.
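class_weight is described only as a {class_label: weight} mapping. One plausible way to apply such a mapping is to expand it into a per-sample weight vector that scales each sample's loss term. The helper below is hypothetical, and giving labels that are absent from the mapping a weight of 1.0 is an assumption, not documented behavior.

```python
import jax.numpy as jnp


def class_weights_to_sample_weights(y, class_weight=None):
    """Hypothetical helper: map a {class_label: weight} dict onto one weight per sample.

    Labels missing from class_weight keep weight 1.0; class_weight=None gives uniform weights.
    """
    weights = jnp.ones(y.shape[0])
    if class_weight is None:
        return weights
    for label, w in class_weight.items():
        # Overwrite the weight wherever the sample carries this label.
        weights = jnp.where(y == label, w, weights)
    return weights
```

For example, `{-1: 3.0}` triples the loss contribution of every negative sample, a common way to counter class imbalance.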
- init_params(n_features: int, key: PRNGKey | None = None)
Initialize model parameters.
- Parameters:
n_features (int) – Number of features in the input data.
key (jax.random.PRNGKey, optional) – Random key for parameter initialization.
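How init_params actually initializes the parameters is not documented. A minimal sketch, assuming a dictionary with a weight vector "w" and a scalar bias "b" (hypothetical keys), small Gaussian weights, and a fixed fallback seed when no key is passed:

```python
import jax
import jax.numpy as jnp


def init_linear_params(n_features, key=None):
    """Hypothetical initializer: small random weights and a zero bias.

    Falls back to PRNGKey(0) when no key is supplied (an assumption,
    not documented behavior of jasmine's init_params).
    """
    if key is None:
        key = jax.random.PRNGKey(0)
    w = 0.01 * jax.random.normal(key, (n_features,))  # one weight per feature
    b = jnp.zeros(())                                  # scalar bias
    return {"w": w, "b": b}
```

Passing an explicit key makes initialization reproducible: the same key always yields the same weights.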
- static forward(params: dict, X: Array) -> Array
Forward pass for the SVM classifier.
- Parameters:
params (dict) – Model parameters
X (jnp.ndarray) – Input features
- Returns:
Predicted values
- Return type:
jnp.ndarray
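For a linear SVM the forward pass is usually the affine decision function f(x) = w·x + b. A minimal sketch, assuming the parameter dictionary stores the weights under "w" and the bias under "b" (hypothetical key names; the documentation does not specify them):

```python
import jax
import jax.numpy as jnp


@jax.jit
def linear_forward(params, X):
    """Hypothetical linear SVM decision function: f(x) = x . w + b for each row of X."""
    return X @ params["w"] + params["b"]
```

The sign of each output indicates the predicted class, and its magnitude is proportional to the sample's distance from the separating hyperplane.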
- train(X: Array, y: Array, validation_data: Tuple | None = None, early_stopping_patience: int | None = None, verbose: int = 1)
Train the SVM classifier.
Important: The labels y must be transformed to {-1, 1}.
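If your labels are encoded as {0, 1}, they can be converted before calling train with a one-line mapping. `to_signed_labels` is a hypothetical helper, not part of jasmine:

```python
import jax.numpy as jnp


def to_signed_labels(y, positive_label=1):
    """Encode binary labels as -1/1: positive_label maps to 1, every other label to -1."""
    return jnp.where(y == positive_label, 1, -1)
```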
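- Parameters:
X (jnp.ndarray) – Training features.
y (jnp.ndarray) – Training labels, encoded as -1 or 1.
validation_data (tuple, optional) – Validation features and labels, used to monitor training and for early stopping.
early_stopping_patience (int, optional) – Number of epochs without improvement on the validation data before training stops early.
verbose (int) – Verbosity level of the training output.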
- Returns:
Fitted model parameters
- Return type:
dict
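The training loop itself is not documented. As a sketch only, assuming full-batch gradient descent on 0.5·||w||² + C·(mean hinge loss) and assuming validation_data is an (X_val, y_val) pair, training with early stopping could be written as below. `hinge_objective`, `train_linear_svm`, and the parameter keys are hypothetical names, and class weights are omitted for brevity.

```python
import jax
import jax.numpy as jnp


def hinge_objective(params, X, y, C):
    """0.5 * ||w||^2 + C * mean hinge loss, with y encoded as -1/1."""
    scores = X @ params["w"] + params["b"]
    hinge = jnp.maximum(0.0, 1.0 - y * scores)
    return 0.5 * jnp.sum(params["w"] ** 2) + C * jnp.mean(hinge)


def train_linear_svm(X, y, C=1.0, learning_rate=0.01, n_epochs=1000,
                     validation_data=None, early_stopping_patience=None):
    """Hypothetical full-batch gradient descent with optional early stopping."""
    params = {"w": jnp.zeros(X.shape[1]), "b": jnp.zeros(())}
    # Gradient with respect to the first argument (the params dict).
    grad_fn = jax.jit(jax.grad(hinge_objective))
    use_early_stopping = (validation_data is not None
                          and early_stopping_patience is not None)
    best_params, best_val, epochs_without_improvement = params, float("inf"), 0

    for _ in range(n_epochs):
        grads = grad_fn(params, X, y, C)
        # One gradient-descent step, applied to every leaf of the params dict.
        params = jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads)

        if not use_early_stopping:
            continue
        X_val, y_val = validation_data
        val_loss = float(hinge_objective(params, X_val, y_val, C))
        if val_loss < best_val:
            best_params, best_val, epochs_without_improvement = params, val_loss, 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= early_stopping_patience:
                break  # validation loss has not improved for `patience` epochs

    return best_params if use_early_stopping else params
```

Returning the parameters with the lowest validation loss, rather than the last ones, is one common design choice; whether jasmine does the same is not stated.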
- inference(X: Array) -> Array
Make predictions using the trained model.
- Parameters:
X (jnp.ndarray) – Input features
- Returns:
Predicted values
- Return type:
jnp.ndarray
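The documentation does not say whether inference returns raw decision values or class labels. If it returns raw scores, a common convention is to threshold them at zero, as in this hypothetical helper (sending ties at exactly 0 to the positive class is an arbitrary choice):

```python
import jax.numpy as jnp


def predict_labels(scores):
    """Hypothetical helper: map raw decision values to -1/1 labels; a score of 0 maps to +1."""
    return jnp.where(scores >= 0, 1, -1)
```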
- evaluate(X: Array, y: Array, metrics_fn=accuracy_score) -> float
Evaluate the model using the specified metrics function.
- Parameters:
X (jnp.ndarray) – Input features
y (jnp.ndarray) – True labels
metrics_fn (callable) – Metrics function to compute the score
- Returns:
Computed metrics score
- Return type:
float
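The default metrics_fn is accuracy_score. Assuming metrics_fn is called as metrics_fn(y_true, y_pred) (the argument order is not documented), any function of that shape that returns a scalar could be substituted. A minimal jitted accuracy function in the same style:

```python
import jax
import jax.numpy as jnp


@jax.jit
def accuracy(y_true, y_pred):
    """Fraction of positions where the predicted label equals the true label."""
    return jnp.mean(y_true == y_pred)
```

It could then be passed as `model.evaluate(X, y, metrics_fn=accuracy)`, where `model` is a trained SVMClassifier.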