[k-NN] Practicing k-Nearest Neighbors classification using cross validation with Python

5 minute read

Understanding the k-Nearest Neighbors algorithm (k-NN)

  • k-NN is one of the simplest supervised machine learning algorithms. It is mostly used for classification, but it can also be used for regression.
  • In k-NN classification, the input consists of the k closest training examples in the dataset, and the output is a class label. Here, k is a positive integer.
  • k-NN stores all available cases and classifies new cases based on a similarity measure.
  • k-NN classifies a new data point based on how its neighbors are classified. To be more specific, the new data point is classified by a plurality vote of its neighbors and assigned to the class most common among its k nearest neighbors (see the sketch below).
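As a concrete sketch of that plurality vote, here is a minimal from-scratch k-NN classifier in NumPy. The function name knn_predict and the toy data are illustrative only, not part of scikit-learn or the original post.

import numpy as np
from collections import Counter

def knn_predict(X_train, y_train, x_new, k=3):
    """Classify x_new by a plurality vote of its k nearest training points."""
    # Euclidean distance from x_new to every training example
    dists = np.linalg.norm(X_train - x_new, axis=1)
    # Indices of the k closest training examples
    nearest = np.argsort(dists)[:k]
    # The most common label among those k neighbors wins the vote
    return Counter(y_train[nearest]).most_common(1)[0][0]

# Toy example: two small clusters and one new point
X_train = np.array([[1.0, 1.0], [1.2, 0.8], [5.0, 5.0], [5.2, 4.9]])
y_train = np.array([0, 0, 1, 1])
print(knn_predict(X_train, y_train, np.array([1.1, 0.9]), k=3))  # -> 0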

How to decide k?

  • The optimal k differs from dataset to dataset.

  • When k is too small:

    • There is a risk of overfitting.
    • Outliers have an outsized effect on how a data point is classified.
    • The resulting pattern is not intuitive.

  • When k is too big:

    • Data points near the class boundaries become difficult to classify correctly.

  • When y is categorical:

    • In binary classification problems, it is helpful to choose k to be an odd number, as this avoids tied votes.

How to calculate distance?

  • Given N training examples (Xi, Yi), i = 1, …, N, order them by their distance to the query point x:
    • d(X(1), x) ≤ … ≤ d(X(N), x)
  • Regarding the distance d(a, b):
    • When the independent variables are categorical: Hamming distance.
    • When the independent variables are continuous: Euclidean distance, Manhattan distance (a small sketch follows below).
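As a minimal sketch of these distance measures (the toy vectors are illustrative only):

import numpy as np

a = np.array([5.1, 3.5])
b = np.array([4.9, 3.0])
print(np.sqrt(np.sum((a - b) ** 2)))   # Euclidean distance
print(np.sum(np.abs(a - b)))           # Manhattan distance

# Hamming distance for categorical variables: number of positions that differ
u = np.array(['red', 'round', 'small'])
v = np.array(['red', 'oval', 'small'])
print(np.sum(u != v))                   # -> 1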



(1) Importing modules and data

from sklearn import neighbors, datasets
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
iris = datasets.load_iris()
Data information
print(iris.DESCR)
.. _iris_dataset:

Iris plants dataset
--------------------

**Data Set Characteristics:**

    :Number of Instances: 150 (50 in each of three classes)
    :Number of Attributes: 4 numeric, predictive attributes and the class
    :Attribute Information:
        - sepal length in cm
        - sepal width in cm
        - petal length in cm
        - petal width in cm
        - class:
                - Iris-Setosa
                - Iris-Versicolour
                - Iris-Virginica
                
    :Summary Statistics:

    ============== ==== ==== ======= ===== ====================
                    Min  Max   Mean    SD   Class Correlation
    ============== ==== ==== ======= ===== ====================
    sepal length:   4.3  7.9   5.84   0.83    0.7826
    sepal width:    2.0  4.4   3.05   0.43   -0.4194
    petal length:   1.0  6.9   3.76   1.76    0.9490  (high!)
    petal width:    0.1  2.5   1.20   0.76    0.9565  (high!)
    ============== ==== ==== ======= ===== ====================

    :Missing Attribute Values: None
    :Class Distribution: 33.3% for each of 3 classes.
    :Creator: R.A. Fisher
    :Donor: Michael Marshall (MARSHALL%PLU@io.arc.nasa.gov)
    :Date: July, 1988

The famous Iris database, first used by Sir R.A. Fisher. The dataset is taken
from Fisher's paper. Note that it's the same as in R, but not as in the UCI
Machine Learning Repository, which has two wrong data points.

This is perhaps the best known database to be found in the
pattern recognition literature.  Fisher's paper is a classic in the field and
is referenced frequently to this day.  (See Duda & Hart, for example.)  The
data set contains 3 classes of 50 instances each, where each class refers to a
type of iris plant.  One class is linearly separable from the other 2; the
latter are NOT linearly separable from each other.

.. topic:: References

   - Fisher, R.A. "The use of multiple measurements in taxonomic problems"
     Annual Eugenics, 7, Part II, 179-188 (1936); also in "Contributions to
     Mathematical Statistics" (John Wiley, NY, 1950).
   - Duda, R.O., & Hart, P.E. (1973) Pattern Classification and Scene Analysis.
     (Q327.D83) John Wiley & Sons.  ISBN 0-471-22361-1.  See page 218.
   - Dasarathy, B.V. (1980) "Nosing Around the Neighborhood: A New System
     Structure and Classification Rule for Recognition in Partially Exposed
     Environments".  IEEE Transactions on Pattern Analysis and Machine
     Intelligence, Vol. PAMI-2, No. 1, 67-71.
   - Gates, G.W. (1972) "The Reduced Nearest Neighbor Rule".  IEEE Transactions
     on Information Theory, May 1972, 431-433.
   - See also: 1988 MLC Proceedings, 54-64.  Cheeseman et al"s AUTOCLASS II
     conceptual clustering system finds 3 classes in the data.
   - Many, many more ...
X = iris.data[:, :2]  # keep only the first two features (sepal length and sepal width) so results can be plotted in 2D
y = iris.target
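As a quick sanity check (not in the original post), the shapes and names of what we just loaded:

print(X.shape)                  # (150, 2): 150 samples, 2 features kept
print(iris.feature_names[:2])   # ['sepal length (cm)', 'sepal width (cm)']
print(iris.target_names)        # ['setosa' 'versicolor' 'virginica']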

(2) Finding an optimal ‘k’ using cross-validation

from sklearn.model_selection import cross_val_score

k_range = range(1, 100)
k_scores = []

# For each candidate k, estimate accuracy with 10-fold cross-validation
for k in k_range:
    knn = neighbors.KNeighborsClassifier(n_neighbors=k)
    scores = cross_val_score(knn, X, y, cv=10, scoring='accuracy')
    k_scores.append(scores.mean())

plt.plot(k_range, k_scores)
plt.xlabel('K-value')
plt.ylabel('Cross-validated accuracy')
plt.show()
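Instead of reading the best k off the plot by eye, it can also be picked programmatically (a small addition that is not in the original post):

best_k = k_range[int(np.argmax(k_scores))]   # k with the highest mean CV accuracy
print(best_k, max(k_scores))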
  • The accuracy is highest when k is around 40 to 45 and decreases after that.
  • Let’s set k to 45 and do the classification with a distance-weighted k-NN.

(3) Distance-weighted k-NN classification (comparing with a baseline k-NN)

  • In this case, the baseline k-NN (weights = ‘uniform’) means that all neighbors get an equally weighted “vote” about an observation’s class. With weights = ‘distance’, on the other hand, each neighbor’s “vote” is weighted by its distance from the observation we are classifying (https://chrisalbon.com/machine_learning/nearest_neighbors/k-nearest_neighbors_classifer/). A toy illustration follows below.
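As a hedged toy illustration of why the weighting matters (the numbers are made up): suppose the 3 nearest neighbors of a query point are one class-B neighbor at distance 0.1 and two class-A neighbors at distance 1.0.

import numpy as np

dists  = np.array([0.1, 1.0, 1.0])   # distances to the 3 nearest neighbors
labels = np.array(['B', 'A', 'A'])   # their class labels

# weights='uniform': every neighbor counts equally, so A wins 2 votes to 1.
# weights='distance': each vote is weighted by 1/distance, so B wins 10.0 to 2.0.
w = 1.0 / dists
for c in ['A', 'B']:
    print(c, w[labels == c].sum())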
n_neighbors = 45

h = .02  # step size in the mesh

cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF'])

for weights in ['uniform', 'distance']:
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X, y)

    # Plot the decision boundary. For that, we will assign a color to each
    # point in the mesh [x_min, x_max]x[y_min, y_max].
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

    # Put the result into a color plot
    Z = Z.reshape(xx.shape)
    plt.figure()
    plt.pcolormesh(xx, yy, Z, cmap=cmap_light, shading='auto')

    # Plot also the training points
    plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold,
                edgecolor='k', s=20)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
    plt.title("3-Class classification (k = %i, weights = '%s')"
              % (n_neighbors, weights))

plt.show()
  • Compared to the baseline k-NN (weights = ‘uniform’), the distance-weighted k-NN shows a cleaner, more intuitive decision boundary.

(4) Comparing confusion matrices

from sklearn.metrics import confusion_matrix
clf = neighbors.KNeighborsClassifier(n_neighbors=45, weights='uniform')
clf.fit(X, y)
y_pred = clf.predict(X)
confusion_matrix(y, y_pred)
array([[50,  0,  0],
       [ 0, 35, 15],
       [ 1, 10, 39]], dtype=int64)
clf2 = neighbors.KNeighborsClassifier(n_neighbors=45, weights='distance')
clf2.fit(X, y)
y_pred2 = clf2.predict(X)
confusion_matrix(y, y_pred2)
array([[50,  0,  0],
       [ 0, 49,  1],
       [ 0, 10, 40]], dtype=int64)
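To put a single number on the comparison, we can also compute the accuracy of each model with sklearn’s accuracy_score (note that both are training-set accuracies, since we predict on the same data the models were fit on; the figures in the comments follow from the confusion matrices above):

from sklearn.metrics import accuracy_score
print(accuracy_score(y, y_pred))    # uniform weights:  124/150 ≈ 0.827
print(accuracy_score(y, y_pred2))   # distance weights: 139/150 ≈ 0.927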
  • The distance-weighted k-NN has better accuracy on the training data than the baseline k-NN.

