import jax
import jax.numpy as jnp
from typing import Optional, Callable
from jasmine.metrics import accuracy_score, euclidean_distance
[docs]
class KNNClassifier:
"""
K-Nearest Neighbors Classifier.
Args:
n_neighbors (int): Number of neighbors to use for classification.
metric (Callable): Distance metric function to use.
random_state (Optional[int]): Random seed for reproducibility.
"""
[docs]
def __init__(
self,
n_neighbors: int = 5,
metric: Callable = euclidean_distance,
random_state: Optional[int] = None,
):
self.n_neighbors = n_neighbors
self.metric = metric
self.random_state = random_state
self.X_train: Optional[jnp.ndarray] = None
self.y_train: Optional[jnp.ndarray] = None
self.n_classes: Optional[int] = None
[docs]
def train(self, X: jnp.ndarray, y: jnp.ndarray):
"""
Train the KNN classifier model by memorizing the training data.
Args:
X (jnp.ndarray): Training features of shape (n_samples, n_features).
y (jnp.ndarray): Training labels of shape (n_samples,).
"""
self.X_train = X
self.y_train = y
self.n_classes = int(jnp.max(y) + 1)
return self
[docs]
def inference(self, X_test: jnp.ndarray) -> jnp.ndarray:
"""
Perform inference on the test data.
Args:
X_test (jnp.ndarray): Test features of shape (n_samples, n_features).
Returns:
jnp.ndarray: Predicted labels for the test data.
"""
if self.X_train is None or self.y_train is None:
raise ValueError("Model must be trained before inference.")
# Create a per-sample prediction function and vectorize over test samples
def predict_fn(x):
return self.predict_single(
x,
self.X_train,
self.y_train,
self.n_neighbors,
self.n_classes,
self.metric,
)
return jax.vmap(predict_fn)(X_test)
[docs]
@staticmethod
def predict_single(
x_test_single: jnp.ndarray,
X_train: jnp.ndarray,
y_train: jnp.ndarray,
n_neighbors: int,
n_classes: int,
metric: Callable,
) -> jnp.ndarray:
"""
Predict the label for a single test instance.
Args:
x_test_single: Single test instance of shape (n_features,).
X_train: Training features of shape (n_samples, n_features).
y_train: Training labels of shape (n_samples,).
n_neighbors: Number of neighbors to consider.
n_classes: Number of classes in the dataset.
metric: Distance metric function.
Returns:
Predicted label for the test instance (scalar jnp.ndarray).
"""
# Compute distances from the test instance to all training instances.
distances = metric(x_test_single, X_train)
# Get indices of the nearest neighbors
neighbor_indices = jnp.argsort(distances)[:n_neighbors]
# Get the labels of the nearest neighbors
neighbor_labels = y_train[neighbor_indices]
# Vote for the majority class.
votes = jnp.bincount(neighbor_labels, length=n_classes)
# Return the voted class label as a JAX scalar (works under vmap)
return jnp.argmax(votes)
[docs]
def evaluate(self, X: jnp.ndarray, y: jnp.ndarray, metric_fn=accuracy_score) -> float:
"""
Evaluate the model using the specified metrics function.
Args:
X (jnp.ndarray): Input features of shape (n_samples, n_features).
y (jnp.ndarray): True labels of shape (n_samples,).
metric_fn (callable): Metrics function to compute the score.
Returns:
float: Computed metrics score.
"""
if self.X_train is None or self.y_train is None:
raise ValueError("Model must be trained before evaluation.")
predictions = self.inference(X)
return metric_fn(y, predictions)