K-Means Clustering Part 1

Just Another Review of K-Means

K-means clustering is a well-known algorithm for which many good online resources already exist. My first introduction to K-means (as was the case with many other machine learning techniques) was from the textbook An Introduction to Statistical Learning. However, a recent experience caused me to revisit this algorithm. After discussing K-means with someone who held some misconceptions about it, I decided to dig deeper into some of the details. Though I doubt those misconceptions are widely held, I wanted to address a few of them. Looking more closely at the algorithm, I learned a few new things as well, in particular the different initialization techniques. Instead of focusing on a few isolated aspects of K-means, it seemed more natural to wrap them together into a general discussion of the algorithm. And that's how another review of K-means clustering was born. I wish I could say that I'm presenting K-means from an exciting new point of view that will completely redefine the way you think about the algorithm. But the truth is I mostly just had fun making some plots and wanted to share the results.

I decided to split the discussion into several posts. In this first post, we'll talk about why one might use K-means clustering, discuss how to choose the number of clusters, and then give a brief overview of the algorithm. In the second part, we'll discuss different initialization techniques and the importance of running the algorithm several times with different random initializations. Then, in a third post, we'll take a closer look at the math behind K-means and examine some data that causes trouble for K-means.

The figures in this post were generated using Matplotlib. All of the code can be found in this Jupyter notebook.

K-Means at a Glance

K-means clustering is an unsupervised learning algorithm which separates data into K distinct groups. As with many other unsupervised methods, the goal is to discover hidden structure within the data. In particular, K-means tries to separate data points into groups where the points in a group are closer to each other than to points in other groups.

For low-dimensional data (2 or 3 variables), we might be able to identify structure directly by plotting the data. For higher-dimensional data, this often cannot be done easily. Furthermore, even when we can identify structure in data by plotting it, we often want to describe that structure in a way that lets us make further use of it.

For example, in the plot below, we can see that the points cluster nicely into four distinct blobs, but if we want to make use of this clustering, we need a criterion for determining which points are part of each cluster.

Four well-separated clusters of points

We might start by noting that the points in the top-left cluster are distinguished from the rest by the property that \(x < 0\) and \(y > 6\). Similar methods work for the bottom-right cluster, but the \(x\)- and \(y\)-coordinates of the two other clusters overlap.

Thus, we turn to unsupervised learning techniques such as K-means clustering to discover structure in our data.

An Application

A common application of clustering techniques is market segmentation. If an online shopping site has access to information about its customers, such as browsing history, purchasing history, age, and geographic location, it can be useful to separate its customer base into different subgroups. The purpose of clustering can simply be descriptive: to get an idea of the customer base. For example, suppose the site's customers are predominantly either college students in New England or retirees in Arizona. Looking at the average customer over the whole data set is not very helpful since, in reality, the site has few middle-aged customers from Missouri. On the other hand, if the site's customers are first split into two subgroups, the average attributes of each subgroup would be more enlightening.

Clustering can also be used to help in supervised learning problems. The site is unlikely to have a huge amount of purchasing data for every single one of its customers. By grouping together similar customers, the site can leverage a lot more data to determine which products a given customer is likely to buy and can target its advertisements accordingly.

Some other applications of K-means are listed here.

An Example: Handwritten Digits

We're going to illustrate the power of K-means through a toy problem, by looking at images of handwritten digits. Specifically, we'll work with the test set of the UCI ML Optical Recognition of Handwritten Digits Data Set which is included in the scikit-learn datasets package. This data set consists of 1797 samples, each corresponding to the pixel intensities of an 8 by 8 pixel grayscale image of a handwritten digit. A few example images are shown below (reconstructed from 8x8 numeric arrays of pixel intensities):
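For readers following along, here is a minimal sketch of loading this data with scikit-learn's `load_digits`. The variable names are our own, and this is not necessarily how the accompanying notebook loads it:

```python
from sklearn.datasets import load_digits

digits = load_digits()  # bundled with scikit-learn, so no download is needed

X = digits.data         # shape (1797, 64): each row is one image flattened to 64 pixels
y = digits.target       # shape (1797,): the true digit (0-9) for each image
images = digits.images  # shape (1797, 8, 8): the same pixel values as 8x8 arrays

print(X.shape, images.shape)  # (1797, 64) (1797, 8, 8)
print(X.min(), X.max())       # pixel intensities range from 0 to 16
```

K-means only ever sees the flattened 64-dimensional rows in `X`; the 8x8 `images` view is just for display.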

Examples of handwritten digits from the dataset

Toy Problem: Separate the digits into 3 groups of digits having similar shapes (or, more precisely, similar pixel intensity values).

We already have a good system for separating these digits into 10 classes (i.e., by digit), but it's not quite as clear how to separate them into just 3 classes. To try to solve this problem, we can apply K-means clustering with \(K = 3\) and look at some of the digits that made it into each cluster.
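A minimal sketch of this step with scikit-learn's `KMeans` follows. The `n_init` and `random_state` values are illustrative choices, not necessarily the settings behind the figure, so the clusters you get may be numbered or composed differently:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)

# K = 3 clusters; n_init and random_state are illustrative, not the post's settings.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0)
labels = kmeans.fit_predict(X)  # one cluster index (0, 1, or 2) per image

# Peek at the true digits of the first few images assigned to each cluster.
for k in range(3):
    members = y[labels == k]
    print(f"cluster {k}: {len(members)} images, e.g. digits {members[:12]}")
```

Note that the true digit labels `y` are never passed to `KMeans`; they are used only afterwards to inspect what ended up in each cluster.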

Digits separated into three clusters

Without knowing anything about digits, the algorithm has nonetheless picked up on similarities. For example, the last cluster seems to contain most of the 0s, 4s, and 6s.

For a more rigorous analysis, consider the following table which records how many instances of each digit made it into each cluster.

$$ \begin{array}{r|rrrrrrrrrr} \text{cluster} \backslash \text{digit} & 0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 \\ \hline 1 & 2 & 38 & 161 & 167 & 0 & 72 & 1 & 0 & 65 & 146 \\ 2 & 0 & 140 & 16 & 16 & 55 & 97 & 1 & 179 & 108 & 34 \\ 3 & 176 & 4 & 0 & 0 & 126 & 13 & 179 & 0 & 1 & 0 \\ \end{array} $$
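A contingency table like this one can be computed with `pandas.crosstab`. In this sketch the clusters are numbered from 0, and since cluster numbering is arbitrary and depends on the initialization, the rows may come out in a different order, and the counts may not match the table exactly:

```python
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

# Rows: cluster index (0-based, arbitrary order); columns: true digit.
table = pd.crosstab(pd.Series(labels, name="cluster"), pd.Series(y, name="digit"))
print(table)
```

Each column sums to the number of images of that digit in the data set, which is a quick sanity check on the table.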

We see that cluster 1 contains most of the 2s, 3s, and 9s; cluster 2 contains most of the 1s and 7s; cluster 3 contains most of the 0s, 4s, and 6s; and the 5s and 8s are mostly split between clusters 1 and 2. Another way to see how the clusters split up the data is to use principal component analysis (PCA), a method for representing high-dimensional data in fewer dimensions (here, two) while preserving as much of the spread (variance) as possible.
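A sketch of the projection with scikit-learn's `PCA`, under the same illustrative `KMeans` settings as before. The clustering is done in the full 64-dimensional space; PCA is used only to display the result in two dimensions:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA

X, y = load_digits(return_X_y=True)
labels = KMeans(n_clusters=3, n_init=10, random_state=0).fit_predict(X)

# Project the 64-dimensional pixel vectors onto the first two principal components.
pca = PCA(n_components=2)
X_2d = pca.fit_transform(X)  # shape (1797, 2)

# Fraction of the total variance captured by each component.
print(pca.explained_variance_ratio_)

# To draw the figure, scatter X_2d colored by cluster label, e.g. with Matplotlib:
#   plt.scatter(X_2d[:, 0], X_2d[:, 1], c=labels)
```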

Clustered digits plotted with the first 2 principal components

Restrictions

K-means clustering cannot be applied to all data sets, and even when it can be applied, it may not give great results:

  • First, the data must be numerical: we need to be able to compute distances between data points and to average data points. There are alternative clustering methods that can deal with categorical or mixed data types, such as K-modes or K-medoids with the Gower metric.
  • The number of clusters K must be set in advance. That is, K-means clustering cannot pick the number of clusters for you. We'll talk about this more in the next section.
  • The goal of K-means is to minimize the (squared) distances between points in the same cluster. This captures a specific type of structure but fails to capture other patterns, and there are numerical data sets that are not well-suited to K-means. We'll discuss this in more detail in a later blog post.
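To make the last point a little more concrete: scikit-learn reports the K-means objective as the sum of squared Euclidean distances from each point to its cluster centroid (the `inertia_` attribute). Summing the squared distances between all ordered pairs of points within each cluster and dividing by the cluster size gives exactly twice that quantity, so "points close to each other" and "points close to their centroid" describe the same goal. The sketch below checks this numerically on synthetic blobs from `make_blobs`, which are an assumption made purely for illustration:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data: four Gaussian blobs (chosen only for illustration).
X, _ = make_blobs(n_samples=400, centers=4, random_state=0)
km = KMeans(n_clusters=4, n_init=10, random_state=0).fit(X)

within_to_centroid = 0.0  # sum of squared distances from points to their cluster mean
within_pairwise = 0.0     # per cluster: sum of pairwise squared distances / cluster size
for k in range(4):
    P = X[km.labels_ == k]
    within_to_centroid += np.sum((P - P.mean(axis=0)) ** 2)
    diffs = P[:, None, :] - P[None, :, :]  # all ordered pairs (i, i')
    within_pairwise += np.sum(diffs ** 2) / len(P)

# The pairwise form is exactly twice the centroid form (up to rounding).
print(within_pairwise / within_to_centroid)

# scikit-learn's inertia_ measures the centroid form against its fitted centers,
# which coincide with the cluster means once the algorithm has converged.
print(within_to_centroid, km.inertia_)
```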

What is K?

The number of clusters, K, is a hyperparameter that must be set in advance. Often, K is chosen using domain expertise. You can also often get an idea for the number of clusters through plotting and data exploration.

Let's explore this further using a real-world data set: Fisher's Iris data set, one of many included in scikit-learn's datasets package. It contains 150 sets of measurements for 3 different species of iris (50 samples from each species).
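A quick look at the data with scikit-learn's `load_iris` confirms the numbers above; this is a sketch, and the variable names are our own:

```python
import numpy as np
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data    # shape (150, 4)
y = iris.target  # species codes 0, 1, 2

print(iris.feature_names)  # sepal/petal length and width, in cm
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']
print(np.bincount(y))      # [50 50 50]
```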

One way to examine the data is to plot each variable against every other variable using a pairplot from the seaborn library.

Pair plot of Iris data without species labels

Looking at these plots, we do indeed see some underlying structure. The five pair plots along with the histogram for petal length separate the data into two distinct groups. Recall though that there should be 3 different species of iris represented. What's going on?

Since we have the luxury of knowing the individual labels (i.e., which iris species each point corresponds to), let's go ahead and look at these plots again after incorporating this information.

Pair plot of Iris data with species labels

We see that the data do indeed cluster into three groups, but that two of those just barely overlap.

Another way of plotting high-dimensional data is to use a dimensionality reduction technique such as PCA. Below, the Iris data set is plotted using the first two principal components.
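A sketch of this projection with scikit-learn's `PCA`. Here PCA is applied to the raw measurements; whether the figure used standardized features is not stated, so treat that as an open choice:

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA

iris = load_iris()

# Project the four measurements onto the first two principal components.
pca = PCA(n_components=2)
X_2d = pca.fit_transform(iris.data)  # shape (150, 2)

print(pca.explained_variance_ratio_)  # share of the total variance kept by each component
```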

Iris data plotted using 2 principal components

So, by exploring the data through visualizations and by using domain expertise, we have come up with two reasonable values of K, namely:

  • \(K = 2\) seems most natural based on the various plots of the data.
  • \(K = 3\) would be an obvious choice if you knew that the data set contained specimens from three different species, but didn't have the species labels included in the data.

The results of performing K-means clustering with \(K = 2\) and \(K = 3\) are presented below.

K-means clustering on Iris data with K = 2 and K = 3

Note that neither one of these solutions perfectly separates the data (either into the 3 distinct species or into setosa/non-setosa), but they come pretty close.
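One way to check how close the two solutions come is to cross-tabulate the cluster labels against the species. This sketch uses illustrative `KMeans` settings, so the exact counts may differ from the figure:

```python
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data
species = pd.Series(iris.target_names[iris.target], name="species")

results = {}
for k in (2, 3):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    results[k] = labels
    # Count how many specimens of each species land in each cluster.
    print(pd.crosstab(pd.Series(labels, name=f"cluster (K={k})"), species), "\n")
```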

In the absence of domain expertise, there are a variety of methods one can use to select the number of clusters.
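Two common heuristics, not necessarily the ones meant above, are the elbow method (plot the within-cluster sum of squares against K and look for a bend) and the average silhouette score (higher is better, on a scale from -1 to 1). A sketch of computing both for the Iris data:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris
from sklearn.metrics import silhouette_score

X = load_iris().data

inertias, silhouettes = {}, {}
for k in range(2, 9):
    km = KMeans(n_clusters=k, n_init=10, random_state=0).fit(X)
    inertias[k] = km.inertia_                          # look for an "elbow" as K grows
    silhouettes[k] = silhouette_score(X, km.labels_)   # average silhouette, in [-1, 1]

for k in inertias:
    print(f"K={k}: inertia={inertias[k]:.1f}, silhouette={silhouettes[k]:.3f}")
```

Inertia almost always shrinks as K grows, so it cannot be minimized directly; the heuristics instead look for the point of diminishing returns.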

A Look at the Algorithm

For a fine algorithm that gleans
Hidden groupings: two, four, maybe teens
Build centroid and label
Keep up 'til it's stable
That's a very rough draft of K-means

K-means clustering is performed by iterating two steps:

  1. Use the cluster labels to recompute the cluster centroids (geometric centers).
  2. Use the cluster centroids to recompute the cluster labels.
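The two steps translate almost directly into NumPy. The sketch below is a minimal illustration of this alternating scheme (often called Lloyd's algorithm), not scikit-learn's implementation. It assumes labels are initialized uniformly at random (initialization is the topic of part 2), and `kmeans_lloyd` is a hypothetical name:

```python
import numpy as np

def kmeans_lloyd(X, k, max_iter=100, seed=0):
    """Hypothetical minimal K-means: alternate steps 1 and 2 until labels stop changing."""
    rng = np.random.default_rng(seed)
    labels = rng.integers(k, size=len(X))  # assumption: random initial labels
    history = []  # within-cluster sum of squares after each relabeling
    for _ in range(max_iter):
        # Step 1: each centroid is the mean of the points carrying that label
        # (an empty cluster is re-seeded at a random data point).
        centroids = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else X[rng.integers(len(X))]
            for j in range(k)
        ])
        # Step 2: relabel every point with the index of its nearest centroid.
        dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        new_labels = dists.argmin(axis=1)
        history.append(dists[np.arange(len(X)), new_labels].sum())
        if np.array_equal(new_labels, labels):
            break  # stable: neither labels nor centroids will change
        labels = new_labels
    return labels, centroids, history

# Usage on synthetic data: four Gaussian blobs (an assumption, for illustration).
rng = np.random.default_rng(42)
true_centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0], [5.0, 0.0]])
X = np.vstack([c + rng.normal(scale=0.7, size=(50, 2)) for c in true_centers])

labels, centroids, history = kmeans_lloyd(X, k=4)
print(f"stable after {len(history)} iterations; objective: {np.round(history, 1)}")
```

The recorded objective never increases from one iteration to the next: step 1 replaces each centroid with the mean of its points, which minimizes that cluster's sum of squares, and step 2 moves each point to a centroid at least as close as its old one.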

Each iteration of these steps reduces (or at worst leaves unchanged) the within-cluster variation, i.e., it readjusts the clusters so that they are more compact. Eventually, the clusters settle down to a stable configuration in which neither the labels nor the centroids change. A few steps of this process are illustrated in the figure below (using \(K = 4\)):

A few clustering steps

On the far left, we start with cluster labels (represented by the colors), which we use to compute the centroids, denoted by stars, in the second figure. This illustrates step 1. Now, notice that the labels don't seem well matched to the centroids: for example, the blue point near (0, 5) is much closer to the red centroid than to the blue one.

So, going from the second figure to the third, we apply step 2 and relabel each data point so that it matches the nearest centroid. Now the labels give us better-looking clusters, but the old centroids are no longer centrally located. The orange star in the third figure is too far to the right; the blue star is too low. We can fix this by applying step 1 again to get the figure on the far right.

In this example, further repetition of these steps won't change anything: we've achieved a stable clustering. However, in general you might need to apply steps 1 and 2 many times before getting your final clusters (and we'll see examples of this later on).

Let's wrap up this section with a small warning: the number of clusters (n_clusters) is an optional parameter of scikit-learn's KMeans class. However, this implementation does not cleverly select the number of clusters based on properties of the data. Instead, it simply uses a fixed default value of 8 clusters.

What's Next?

You may have noticed a problem with the description of K-means clustering given above: To perform step 1, we need to already have cluster labels. We can get labels from step 2, but only if we already have cluster centroids. Thus, before iterating these steps, we need to either initialize labels or initialize centroids.

There are various methods for initializing the labels or the centroids. The reason to consider different initialization techniques is that, even though K-means is guaranteed to converge to a stable clustering, that clustering might not be optimal: different initializations can lead to different outcomes. To avoid having to restart too many times, we want to initialize in a clever way. In part 2 of this series, we'll cover several initialization methods and discuss the importance of multiple random initializations in more detail.
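As a preview, here is a sketch of what restarting from different random initializations looks like with scikit-learn's `KMeans`. The synthetic data from `make_blobs` and all parameter values are illustrative assumptions. `init="random"` picks random data points as starting centroids, and `n_init` controls how many runs are made, with the lowest-inertia run kept:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Synthetic data: six somewhat overlapping blobs (an assumption, for illustration).
X, _ = make_blobs(n_samples=500, centers=6, cluster_std=1.5, random_state=3)

# One run per seed, each from a single random initialization.
inertias = []
for seed in range(10):
    km = KMeans(n_clusters=6, init="random", n_init=1, random_state=seed).fit(X)
    inertias.append(km.inertia_)
print([round(v, 1) for v in inertias])  # runs may settle into different local optima

# KMeans can do the restarts itself: n_init runs, keeping the lowest-inertia result.
best = KMeans(n_clusters=6, init="random", n_init=10, random_state=0).fit(X)
print(best.inertia_)
```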
