jasmine.neighbors.KNNClassifier

class jasmine.neighbors.KNNClassifier(n_neighbors: int = 5, metric: Callable = euclidean_distance, random_state: int | None = None)

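K-nearest neighbors (KNN) classifier. Each test sample is assigned the label that occurs most often among its n_neighbors closest training samples, with closeness measured by metric.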

Parameters:
  • n_neighbors (int) – Number of neighbors to use for classification.

  • metric (Callable) – Function that computes the distance between samples. Defaults to euclidean_distance (jit-compiled).

  • random_state (Optional[int]) – Random seed for reproducibility.
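The metric parameter accepts a distance function. This page does not state the exact calling convention, so the sketch below assumes the metric receives two 1-D feature vectors of shape (n_features,) and returns a scalar distance. The names euclidean and manhattan are hypothetical and not part of jasmine.

```python
import jax
import jax.numpy as jnp


# Hypothetical metrics, assuming metric(a, b) takes two 1-D vectors
# of shape (n_features,) and returns a scalar distance.
@jax.jit
def euclidean(a: jax.Array, b: jax.Array) -> jax.Array:
    # Square root of the summed squared differences (L2 distance).
    return jnp.sqrt(jnp.sum((a - b) ** 2))


@jax.jit
def manhattan(a: jax.Array, b: jax.Array) -> jax.Array:
    # Sum of absolute differences (L1 distance).
    return jnp.sum(jnp.abs(a - b))


a = jnp.array([0.0, 0.0])
b = jnp.array([3.0, 4.0])
print(float(euclidean(a, b)))  # 5.0
print(float(manhattan(a, b)))  # 7.0
```

If the library does call the metric with this signature, a custom distance could be supplied as KNNClassifier(metric=manhattan).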

__init__(n_neighbors: int = 5, metric: Callable = euclidean_distance, random_state: int | None = None)

train(X: Array, y: Array)

Train the classifier by storing the training data. KNN learns no parameters, so all distance computations are deferred to inference.

Parameters:
  • X (jnp.ndarray) – Training features of shape (n_samples, n_features).

  • y (jnp.ndarray) – Training labels of shape (n_samples,).
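Because KNN is a lazy learner, training amounts to little more than validating shapes and keeping the arrays. A minimal sketch under that reading follows. The class name MemorizingKNN is hypothetical, and deriving n_classes from the largest label assumes integer labels in the range 0 to n_classes - 1; neither is taken from jasmine's implementation.

```python
import jax.numpy as jnp


class MemorizingKNN:
    """Hypothetical stand-in illustrating what a KNN train step stores."""

    def __init__(self, n_neighbors: int = 5):
        self.n_neighbors = n_neighbors

    def train(self, X, y):
        X = jnp.asarray(X)
        y = jnp.asarray(y)
        # Nothing is fitted: check shapes, then keep the data for inference.
        if X.ndim != 2 or y.ndim != 1 or X.shape[0] != y.shape[0]:
            raise ValueError(
                "expected X of shape (n_samples, n_features) and y of shape (n_samples,)"
            )
        if self.n_neighbors > X.shape[0]:
            raise ValueError("n_neighbors cannot exceed n_samples")
        self.X_train = X
        self.y_train = y.astype(jnp.int32)
        # Assumes integer labels 0 .. n_classes - 1.
        self.n_classes = int(y.max()) + 1
        return self


model = MemorizingKNN(n_neighbors=2).train([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]], [0, 0, 1])
print(model.X_train.shape, model.n_classes)  # (3, 2) 2
```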

inference(X_test: Array) -> Array

Predict a label for each row of X_test.

Parameters:

X_test (jnp.ndarray) – Test features of shape (n_samples, n_features).

Returns:

Predicted labels for the test data.

Return type:

jnp.ndarray
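Inference presumably applies the single-instance prediction to every row of X_test. As an illustration of the same idea computed in one batched pass, the hypothetical function below builds the full test-by-train Euclidean distance matrix by broadcasting, takes the n_neighbors nearest training samples per row, and counts votes with one-hot encoding. It hard-codes Euclidean distance and assumes integer labels in [0, n_classes); it is a sketch, not the library's code.

```python
import jax
import jax.numpy as jnp


def batched_inference(X_test, X_train, y_train, n_neighbors: int, n_classes: int):
    # Pairwise Euclidean distances via broadcasting: shape (n_test, n_train).
    diffs = X_test[:, None, :] - X_train[None, :, :]
    distances = jnp.sqrt(jnp.sum(diffs ** 2, axis=-1))
    # top_k works along the last axis; negate so the smallest distances win.
    _, nearest_idx = jax.lax.top_k(-distances, n_neighbors)  # (n_test, k)
    neighbor_labels = y_train[nearest_idx]                    # (n_test, k)
    # One-hot encode neighbor labels and sum over neighbors to get vote counts.
    votes = jax.nn.one_hot(neighbor_labels, n_classes).sum(axis=1)  # (n_test, n_classes)
    return jnp.argmax(votes, axis=-1)


X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]])
y_train = jnp.array([0, 0, 1, 1])
X_test = jnp.array([[0.2, 0.4], [5.5, 4.8]])
preds = batched_inference(X_test, X_train, y_train, n_neighbors=3, n_classes=2)
print(preds.tolist())  # [0, 1]
```

The broadcast materializes an (n_test, n_train, n_features) intermediate array, so very large inputs may need to be processed in chunks.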

static predict_single(x_test_single: Array, X_train: Array, y_train: Array, n_neighbors: int, n_classes: int, metric: Callable) -> Array

Predict the label for a single test instance.

Parameters:
  • x_test_single – Single test instance of shape (n_features,).

  • X_train – Training features of shape (n_samples, n_features).

  • y_train – Training labels of shape (n_samples,).

  • n_neighbors – Number of neighbors to consider.

  • n_classes – Number of classes in the dataset.

  • metric – Distance metric function.

Returns:

Predicted label for the test instance (scalar jnp.ndarray).
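The single-instance steps can be sketched as follows: compute the distance to every training sample, keep the n_neighbors smallest, and return the class with the most votes. This is a reconstruction from the parameter list, not jasmine's source; it assumes metric compares two 1-D vectors and that labels are integers in [0, n_classes).

```python
from typing import Callable

import jax
import jax.numpy as jnp


def predict_single(x_test_single: jax.Array, X_train: jax.Array, y_train: jax.Array,
                   n_neighbors: int, n_classes: int, metric: Callable) -> jax.Array:
    # Distance from the test point to every training point: shape (n_samples,).
    distances = jax.vmap(lambda x_train: metric(x_test_single, x_train))(X_train)
    # lax.top_k selects the largest values, so negate to get the smallest distances.
    _, nearest_idx = jax.lax.top_k(-distances, n_neighbors)
    # Vote count per class; a fixed length keeps the output shape static.
    votes = jnp.bincount(y_train[nearest_idx], length=n_classes)
    # argmax returns the first maximum, so ties go to the smallest class index.
    return jnp.argmax(votes)


def euclidean(a: jax.Array, b: jax.Array) -> jax.Array:
    return jnp.sqrt(jnp.sum((a - b) ** 2))


X_train = jnp.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]])
y_train = jnp.array([0, 0, 1, 1])
label = predict_single(jnp.array([0.2, 0.4]), X_train, y_train, 3, 2, euclidean)
print(int(label))  # 0
```

Passing n_classes explicitly is plausibly what lets the vote vector have a fixed length, which JAX needs under jit because array shapes must be known at trace time.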

evaluate(X: Array, y: Array, metric_fn=accuracy_score) -> float

Evaluate the model on X and y using metric_fn.

Parameters:
  • X (jnp.ndarray) – Input features of shape (n_samples, n_features).

  • y (jnp.ndarray) – True labels of shape (n_samples,).

  • metric_fn (callable) – Function that computes a score from the true and predicted labels. Defaults to accuracy_score (jit-compiled).

Returns:

The score computed by metric_fn.

Return type:

float
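With the default metric_fn, evaluation reduces to the fraction of correct predictions. The sketch below shows that computation. The names accuracy and evaluate, the predict_fn argument, and the (y_true, y_pred) argument order are assumptions for illustration, not the documented signature of accuracy_score.

```python
import jax
import jax.numpy as jnp


@jax.jit
def accuracy(y_true: jax.Array, y_pred: jax.Array) -> jax.Array:
    # Fraction of positions where the predicted label matches the true label.
    return jnp.mean(y_true == y_pred)


def evaluate(predict_fn, X: jax.Array, y: jax.Array, metric_fn=accuracy) -> float:
    # Predict, score, then convert the 0-d JAX array to a Python float.
    y_pred = predict_fn(X)
    return float(metric_fn(y, y_pred))


# Toy predictor: label 1 when the first feature is positive.
X = jnp.array([[-1.0], [2.0], [3.0], [4.0]])
y = jnp.array([0, 1, 1, 0])
score = evaluate(lambda X: (X[:, 0] > 0).astype(jnp.int32), X, y)
print(score)  # 0.75
```

Converting to float at the end matches the documented float return type and keeps the result usable outside JAX.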