jasmine.neighbors.KNNClassifier
- class jasmine.neighbors.KNNClassifier(n_neighbors: int = 5, metric: Callable = <PjitFunction of <function euclidean_distance>>, random_state: int | None = None)
K-Nearest Neighbors Classifier.
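- Parameters:
n_neighbors (int) – Number of neighbors to consider (default 5).
metric (callable) – Distance metric function (default: a jit-compiled euclidean_distance).
random_state (int | None) – Random seed (default None).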
- __init__(n_neighbors: int = 5, metric: Callable = <PjitFunction of <function euclidean_distance>>, random_state: int | None = None)
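The default metric shows up in the signature as a `PjitFunction`, which is the object `jax.jit` returns, wrapping a function named `euclidean_distance`. A minimal sketch of such a metric follows. The body and the calling convention (a batch of points broadcast against one point along the last axis) are assumptions for illustration, not the library's code.

```python
import jax
import jax.numpy as jnp


@jax.jit
def euclidean_distance(x1: jax.Array, x2: jax.Array) -> jax.Array:
    # Assumed behavior: L2 distance along the feature (last) axis, so a
    # (n_samples, n_features) batch broadcasts against a (n_features,) point.
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2, axis=-1))


# Two single vectors give a scalar; a batch against one point gives one distance per row.
single = euclidean_distance(jnp.array([0.0, 0.0]), jnp.array([3.0, 4.0]))
batch = euclidean_distance(jnp.array([[0.0, 0.0], [6.0, 8.0]]), jnp.array([0.0, 0.0]))
```

Summing along `axis=-1` rather than over all axes is what lets the same function serve both the single-pair and the batched case.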
- train(X: Array, y: Array)
Train the KNN classifier by memorizing (storing) the training data; no model parameters are fitted.
- Parameters:
X (jnp.ndarray) – Training features of shape (n_samples, n_features).
y (jnp.ndarray) – Training labels of shape (n_samples,).
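Because KNN defers all real work to prediction time, "training" amounts to keeping the data. A minimal sketch of that idea follows. The class name `MemorizingKNN` is hypothetical, and so is the step that derives `n_classes` from integer labels `0..n_classes-1`; it is included only because `predict_single` needs a class count from somewhere.

```python
import jax.numpy as jnp


class MemorizingKNN:
    """Hypothetical minimal sketch of KNN 'training'; not the jasmine class."""

    def __init__(self, n_neighbors: int = 5):
        self.n_neighbors = n_neighbors
        self.X_train = None
        self.y_train = None
        self.n_classes = None

    def train(self, X, y):
        # Nothing is optimized: the stored training set is the model.
        self.X_train = jnp.asarray(X)
        self.y_train = jnp.asarray(y)
        # Assumption: labels are the integers 0..n_classes-1.
        self.n_classes = int(self.y_train.max()) + 1


model = MemorizingKNN(n_neighbors=3)
model.train(jnp.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]), jnp.array([0, 0, 1]))
```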
- inference(X_test: Array) → Array
Predict a label for each sample in the test data.
- Parameters:
X_test (jnp.ndarray) – Test features of shape (n_samples, n_features).
- Returns:
Predicted labels for the test data.
- Return type:
jnp.ndarray
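Since `predict_single` handles one test row, a natural way to build `inference` is to map it over the rows of `X_test` with `jax.vmap`. The sketch below assumes that design; it is not taken from the library source, and the free functions stand in for the class methods. The training data is captured in a closure so that only `X_test` is mapped.

```python
from functools import partial

import jax
import jax.numpy as jnp


def euclidean_distance(x1, x2):
    # L2 distance along the feature axis; broadcasts (n, f) against (f,).
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2, axis=-1))


def predict_single(x, X_train, y_train, n_neighbors, n_classes, metric):
    # Majority vote among the n_neighbors closest training points.
    _, idx = jax.lax.top_k(-metric(X_train, x), n_neighbors)
    return jnp.argmax(jnp.bincount(y_train[idx], length=n_classes))


# n_neighbors, n_classes and metric are static: they fix shapes or are Python callables.
@partial(jax.jit, static_argnums=(3, 4, 5))
def inference(X_test, X_train, y_train, n_neighbors, n_classes, metric):
    # Close over the training data and vectorize across test rows only.
    predict = lambda x: predict_single(x, X_train, y_train, n_neighbors, n_classes, metric)
    return jax.vmap(predict)(X_test)


X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0]])
y_train = jnp.array([0, 0, 0, 1, 1])
preds = inference(jnp.array([[0.2, 0.2], [5.0, 5.4]]), X_train, y_train, 3, 2, euclidean_distance)
```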
- static predict_single(x_test_single: Array, X_train: Array, y_train: Array, n_neighbors: int, n_classes: int, metric: Callable) → Array
Predict the label for a single test instance.
- Parameters:
x_test_single – Single test instance of shape (n_features,).
X_train – Training features of shape (n_samples, n_features).
y_train – Training labels of shape (n_samples,).
n_neighbors – Number of neighbors to consider.
n_classes – Number of classes in the dataset.
metric – Distance metric function.
- Returns:
Predicted label for the test instance (scalar jnp.ndarray).
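The steps behind a single prediction can be sketched as: compute the distance to every training point, take the `n_neighbors` closest, and return the most common label among them. This is a hedged illustration of standard majority-vote KNN, not the library's implementation; the metric's calling convention and the tie-breaking rule (lowest class index wins, because `jnp.argmax` returns the first maximum) are assumptions.

```python
import jax
import jax.numpy as jnp


def euclidean_distance(x1, x2):
    # Assumed metric convention: distances along the last (feature) axis.
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2, axis=-1))


def predict_single(x_test_single, X_train, y_train, n_neighbors, n_classes, metric):
    # 1. Distance from the test point to every training point: (n_samples,).
    distances = metric(X_train, x_test_single)
    # 2. lax.top_k returns the largest values, so negate to get the nearest.
    _, neighbor_idx = jax.lax.top_k(-distances, n_neighbors)
    # 3. Count votes per class; a fixed `length` keeps the output shape static.
    votes = jnp.bincount(y_train[neighbor_idx], length=n_classes)
    # 4. Majority label; ties resolve to the lowest class index.
    return jnp.argmax(votes)


X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [5.0, 5.0], [5.0, 6.0]])
y_train = jnp.array([0, 0, 0, 1, 1])
label = predict_single(jnp.array([5.0, 5.4]), X_train, y_train, 3, 2, euclidean_distance)
```

Taking `n_classes` as an explicit argument fits JAX's tracing model: `jnp.bincount` needs a concrete `length` under `jit` or `vmap`, since it cannot infer the output size from traced label values.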
- evaluate(X: Array, y: Array, metric_fn=<PjitFunction of <function accuracy_score>>) → float
Evaluate the model using the specified metric function.
- Parameters:
X (jnp.ndarray) – Input features of shape (n_samples, n_features).
y (jnp.ndarray) – True labels of shape (n_samples,).
metric_fn (callable) – Metric function used to compute the score.
- Returns:
Computed metric score.
- Return type:
float
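`evaluate` presumably predicts labels for `X` and passes the true and predicted labels to `metric_fn`, whose default is a jit-compiled `accuracy_score`. A minimal sketch under those assumptions follows. The `(y_true, y_pred)` argument order and the hypothetical `predict_fn` parameter (standing in for the model's own `inference`) are not confirmed by the documentation above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accuracy_score(y_true, y_pred):
    # Fraction of samples whose predicted label equals the true label.
    return jnp.mean(y_true == y_pred)


def evaluate(predict_fn, X, y, metric_fn=accuracy_score):
    # Hypothetical: predict first, then score; float() converts the 0-d array.
    y_pred = predict_fn(X)
    return float(metric_fn(y, y_pred))


# Toy predictor for illustration: class 1 if the first feature exceeds 2.
predict_fn = lambda X: jnp.where(X[:, 0] > 2.0, 1, 0)
X = jnp.array([[0.0, 0.0], [5.0, 5.0], [1.0, 1.0], [6.0, 6.0]])
y = jnp.array([0, 1, 1, 1])
score = evaluate(predict_fn, X, y)
```

Here the toy predictor gets three of four labels right, so `score` is 0.75, returned as a Python `float` to match the documented return type.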