Ciro Santilli OurBigBook.com  Sponsor 中国独裁统治 China Dictatorship 新疆改造中心、六四事件、法轮功、郝海东、709大抓捕、2015巴拿马文件 邓家贵、低端人口、西藏骚乱
python/sklearn/knn.py
#!/usr/bin/env python3

# https://cirosantilli.com/scikit-learn
# https://scikit-learn.org/stable/auto_examples/neighbors/plot_classification.html

import os

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from sklearn import neighbors, datasets
from sklearn.inspection import DecisionBoundaryDisplay

# Render one k-nearest-neighbours decision-boundary plot of the iris dataset
# per value of k, saving each as out/knn-<k>.png.
os.makedirs('out', exist_ok=True)

# Everything in this section is independent of k, so it is prepared once
# instead of being recomputed on every loop iteration.
iris = datasets.load_iris()

# Only the first two features are kept so the decision regions can be drawn
# on a 2-D plane.
X = iris.data[:, :2]
y = iris.target

# Light colours fill the decision regions; bold colours mark training points.
cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
cmap_bold = ["darkorange", "c", "darkblue"]

# Single source of truth for the weighting scheme, shared by the classifier
# and the plot title so the two cannot drift apart.
weights = 'uniform'

for n_neighbors in range(1, 100, 5):
    # Fit a fresh classifier for this neighbourhood size.
    clf = neighbors.KNeighborsClassifier(n_neighbors, weights=weights)
    clf.fit(X, y)

    fig, ax = plt.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=cmap_light,
        ax=ax,
        response_method="predict",
        plot_method="pcolormesh",
        xlabel=iris.feature_names[0],
        ylabel=iris.feature_names[1],
        shading="auto",
    )

    # Overlay the training points on the same axes (passed explicitly rather
    # than relying on pyplot's implicit "current axes").
    sns.scatterplot(
        x=X[:, 0],
        y=X[:, 1],
        hue=iris.target_names[y],
        palette=cmap_bold,
        alpha=1.0,
        edgecolor="black",
        ax=ax,
    )
    ax.set_title(
        f"3-Class classification (k = {n_neighbors}, weights = '{weights}')"
    )
    fig.savefig(
        os.path.join('out', f'knn-{n_neighbors:02d}.png'),
        bbox_inches='tight',
    )
    # pyplot keeps every figure alive until it is closed; without this the
    # loop accumulates one open figure per iteration.
    plt.close(fig)