In this blog and the next few, I will discuss the implementation of the k-nearest neighbor (k-NN) algorithm using Python libraries. My objective in this blog is to explain, in a simple manner, how building data structures can speed up the classification of a data record in the k-nearest neighbor algorithm. I will explain two indexing structures, the k-d tree and the ball tree, both of which are available in Python libraries.

**Generic K-Nearest Neighbor algorithm**

k-NN is a classification algorithm and conceptually one of the simplest to understand. It is also called a ‘lazy learner’, as opposed to an ‘eager learner’. Most classification algorithms are eager learners: there is a set of training data with example classifications, and this training data is used to construct a classification model. The model is evaluated on test data whose classifications are known. If the results are satisfactory, the final model is then used to predict classes for data with unknown classifications. Eager learners have, therefore, already done most of the work of model formulation beforehand. A lazy learner, on the other hand, does not build any model beforehand; it waits for the unclassified data and then winds its way through the algorithm to make a classification prediction. Lazy learners are, therefore, time consuming at prediction time: each time a prediction is to be made, all the model-building effort has to be performed again.

In the k-nearest neighbor algorithm, the example data is first plotted in an n-dimensional space, where ‘n’ is the number of data attributes. Each point in this n-dimensional space is labeled with its class value. To discover the classification of an unclassified record, its point is plotted in the same space and the class labels of the k nearest data points are noted. Generally k is an odd number, which avoids tied votes when there are two classes. The class that occurs most often among the k nearest data points is taken as the class of the new data point. That is, the decision is made by a vote among the k neighboring points. One of the big advantages of this generic k-nearest neighbor algorithm is that it is amenable to parallel operations.
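The voting step can be sketched in a few lines of plain Python. This is only a minimal illustration: the function name `knn_predict`, the tiny two-dimensional data set and the class labels ‘red’ and ‘blue’ are made up for the example, and Euclidean distance is assumed (`math.dist` needs Python 3.8 or later).

```python
import math
from collections import Counter

# Hypothetical training data: (point, class label) pairs.
TRAIN = [((1, 1), "red"), ((2, 1), "red"), ((1, 2), "red"),
         ((6, 6), "blue"), ((7, 6), "blue"), ((6, 7), "blue")]

def knn_predict(train, query, k=3):
    """Classify `query` by a majority vote among its k nearest training points."""
    # Brute force: measure the distance to every training point, keep the k closest.
    nearest = sorted(train, key=lambda item: math.dist(item[0], query))[:k]
    votes = Counter(label for _, label in nearest)
    return votes.most_common(1)[0][0]

print(knn_predict(TRAIN, (2, 2), k=3))  # red
print(knn_predict(TRAIN, (6, 5), k=3))  # blue
```

Every call sorts the whole training set by distance, which is exactly the cost the indexing structures discussed in this blog try to avoid.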

In the generic k-NN model, each time a prediction is to be made for a data point, its distance from every training point must first be calculated; only then can the k nearest points be found for voting. This is also known as the brute-force approach. When the volume of data is huge and its dimension is also very large, say in the hundreds or thousands, these repeated distance calculations can be very tedious and time consuming. To speed up this process and avoid measuring distances to all the points in the data set, some preprocessing of the training data is done. This preprocessing helps to restrict the search to points that are likely to be in the neighborhood of the query point.

**K-d tree formation — Speeding up K-NN**

One way is to construct a sorted hierarchical data structure called a k-d tree, or k-dimensional tree. A k-dimensional tree is a binary tree. Its formation is illustrated below through a worked example.

Consider the three-dimensional (training) data set shown in Table 0, below left. For convenience of representation, the fourth column containing the class label of each data record is not shown. We have three attributes, ‘a’, ‘b’ and ‘c’. Among the three, attribute ‘b’ has the greatest variance. We sort the data set on this attribute (Table 1) and then divide it into two parts at the median.

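Table 0: unsorted data (left three columns). Table 1: sorted on column b (right three columns); `<---` marks the median row.

| a | b | c | | a | b | c | |
|---|---|---|---|---|---|---|---|
| 22 | 38 | 21 | | 6 | 2 | 9 | |
| 4 | 8 | 6 | | 4 | 8 | 6 | |
| 2 | 14 | 3 | | 2 | 14 | 3 | |
| 8 | 20 | 12 | | 8 | 20 | 12 | |
| 10 | 26 | 18 | | 10 | 26 | 18 | |
| 12 | 32 | 15 | | 12 | 32 | 15 | <--- |
| 18 | 56 | 33 | | 22 | 38 | 21 | |
| 16 | 44 | 27 | | 16 | 44 | 27 | |
| 20 | 50 | 24 | | 20 | 50 | 24 | |
| 14 | 62 | 30 | | 18 | 56 | 33 | |
| 6 | 2 | 9 | | 14 | 62 | 30 | |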
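The choice of the splitting column and of the median row can be checked with a short script. It is a sketch only: it assumes the population variance (`statistics.pvariance`), which the text does not specify; the sample variance would rank the three columns in the same order because all columns have the same number of rows.

```python
from statistics import pvariance

# The eleven records of Table 0 as (a, b, c) tuples.
POINTS = [(22, 38, 21), (4, 8, 6), (2, 14, 3), (8, 20, 12), (10, 26, 18),
          (12, 32, 15), (18, 56, 33), (16, 44, 27), (20, 50, 24),
          (14, 62, 30), (6, 2, 9)]
COLUMNS = "abc"

# Variance of each attribute, keyed by column name.
variances = {name: pvariance([p[i] for p in POINTS])
             for i, name in enumerate(COLUMNS)}

# Split on the column with the greatest variance.
split_name = max(variances, key=variances.get)
split_dim = COLUMNS.index(split_name)

# Sort on that column and take the middle row as the median.
ordered = sorted(POINTS, key=lambda p: p[split_dim])
median = ordered[len(ordered) // 2]

print("variances:", variances)
print("split on column:", split_name)
print("median row:", median)
```

The variances of a, b and c come out as 40, 360 and 90, so column b is chosen, and the middle row of the sorted data is (12, 32, 15).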

The median is at (12,32,15). Dividing Table 1 into two parts at the median gives us two tables, Table 2 and Table 3, as below. Next, from among the remaining two attributes, we select the dimension that has the greatest variance, which is ‘c’. Again, we sort the two tables on this dimension and then break them at their respective medians.

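Break Table 1 on median (12,32,15). Table 2: rows with b > 32 (left three columns). Table 3: rows with b < 32 (right three columns).

| a | b | c | | a | b | c |
|---|---|---|---|---|---|---|
| 22 | 38 | 21 | | 6 | 2 | 9 |
| 16 | 44 | 27 | | 4 | 8 | 6 |
| 20 | 50 | 24 | | 2 | 14 | 3 |
| 18 | 56 | 33 | | 8 | 20 | 12 |
| 14 | 62 | 30 | | 10 | 26 | 18 |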

The tables sorted on column c are as below.

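Table 2 sorted on column c (left three columns, now called Table 3) and the earlier Table 3 sorted on column c (right three columns, now called Table 4); `<---` marks each median row.

| a | b | c | | | a | b | c | |
|---|---|---|---|---|---|---|---|---|
| 22 | 38 | 21 | | | 2 | 14 | 3 | |
| 20 | 50 | 24 | | | 4 | 8 | 6 | |
| 16 | 44 | 27 | <--- | | 6 | 2 | 9 | <--- |
| 14 | 62 | 30 | | | 8 | 20 | 12 | |
| 18 | 56 | 33 | | | 10 | 26 | 18 | |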

Table 3 (the sorted upper half) is next split at its median, (16,44,27), and Table 4 (the sorted lower half) is split at its median, (6,2,9), as below.

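Break Table 3 on median (16,44,27): rows with c > 27 (left three columns) and rows with c < 27 (right three columns).

| a | b | c | | a | b | c |
|---|---|---|---|---|---|---|
| 14 | 62 | 30 | | 22 | 38 | 21 |
| 18 | 56 | 33 | | 20 | 50 | 24 |

Break Table 4 on median (6,2,9): rows with c > 9 (left three columns) and rows with c < 9 (right three columns).

| a | b | c | | a | b | c |
|---|---|---|---|---|---|---|
| 8 | 20 | 12 | | 2 | 14 | 3 |
| 10 | 26 | 18 | | 4 | 8 | 6 |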

We now have four tables. If we decide to end the splitting process here, these four tables are the leaves of the tree. Otherwise, we would next split by sorting on column ‘a’ (and then split on ‘b’, ‘c’, and so on).
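The whole splitting procedure can be written as a short recursive function. The sketch below is one possible reading of the steps above, not a library implementation: the names `build_kd`, `unused_dims` and `levels` are made up for the example, population variance is assumed, and each median row is kept at its internal node so that the leaves match the four tables of the worked example.

```python
from statistics import pvariance

# The eleven records of Table 0 as (a, b, c) tuples.
POINTS = [(22, 38, 21), (4, 8, 6), (2, 14, 3), (8, 20, 12), (10, 26, 18),
          (12, 32, 15), (18, 56, 33), (16, 44, 27), (20, 50, 24),
          (14, 62, 30), (6, 2, 9)]

def build_kd(points, unused_dims, levels):
    """Return a leaf (a list of points) or a node dict with a median and two subtrees."""
    if levels == 0 or len(points) <= 1:
        return points
    if not unused_dims:
        # Every column has been used once: make all of them available again.
        unused_dims = list(range(len(points[0])))
    # Split on the not-yet-used column with the greatest variance.
    dim = max(unused_dims, key=lambda d: pvariance([p[d] for p in points]))
    ordered = sorted(points, key=lambda p: p[dim])
    mid = len(ordered) // 2
    rest = [d for d in unused_dims if d != dim]
    return {
        "dim": dim,                                   # 0 = a, 1 = b, 2 = c
        "median": ordered[mid],                       # row at which the table is broken
        "low": build_kd(ordered[:mid], rest, levels - 1),
        "high": build_kd(ordered[mid + 1:], rest, levels - 1),
    }

# Two levels of splitting, as in the worked example.
tree = build_kd(POINTS, [0, 1, 2], levels=2)
print("root median:", tree["median"])
for side in ("low", "high"):
    child = tree[side]
    print(side, "median:", child["median"], "leaves:", child["low"], child["high"])
```

With `levels=2` the function stops at the four leaf tables; a larger value would continue with column ‘a’.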

Once this data structure is created, it is easy to find the approximate neighborhood of any point. For example, to find the neighborhood of the point (9,25,16), we move down the hierarchy, going left or right at each node. First, at the root node, we compare 25 with the value at the root (from column b); at the next node we compare 16 (from column c); and lastly 9 (from column a). The points at the leaf we reach are possibly, but not necessarily, the nearest points. Generally, distances from the points in the table on the other side of this node are also calculated in order to discover the nearest points. One may also move a step up the tree to discover nearest points. Incidentally, the median points ((12,32,15), for example) are also made a part of either the left or the right sub-tree.
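The descent for the point (9,25,16) can be traced with a small script. As a sketch, the tree of the worked example is written out by hand as nested tuples (an assumed representation), with each median kept at its node. It stops at the four leaves, so the final comparison on column a is not made. The result is compared with a brute-force search over all eleven points.

```python
import math

# The tree from the worked example, written out by hand.
# Internal node: (split_column, median_point, low_side, high_side); leaf: list of points.
TREE = (1, (12, 32, 15),
        (2, (6, 2, 9), [(2, 14, 3), (4, 8, 6)], [(8, 20, 12), (10, 26, 18)]),
        (2, (16, 44, 27), [(22, 38, 21), (20, 50, 24)], [(14, 62, 30), (18, 56, 33)]))

def descend(node, query):
    """Walk from the root to a leaf; return the leaf and the medians passed on the way."""
    passed = []
    while isinstance(node, tuple):
        dim, median, low, high = node
        passed.append(median)
        node = low if query[dim] < median[dim] else high
    return node, passed

def all_points(node):
    """Every point stored in the tree, used for the brute-force comparison."""
    if isinstance(node, list):
        return list(node)
    _, median, low, high = node
    return [median] + all_points(low) + all_points(high)

def by_distance(points, query):
    """Points ordered from nearest to farthest (Euclidean distance)."""
    return sorted(points, key=lambda p: math.dist(p, query))

query = (9, 25, 16)
leaf, passed = descend(TREE, query)
print("leaf reached:", leaf)
print("nearest in leaf:", by_distance(leaf, query)[0])
print("3 nearest overall:", by_distance(all_points(TREE), query)[:3])
```

For this query the leaf does contain the single nearest point, (10,26,18). The third-nearest point overall, however, is the root median (12,32,15), which is not in the leaf. This is why a search that needs exact answers also checks the medians it passed and, where necessary, the other side of a split.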

**Ball-tree data structure**

Another data structure that speeds up the discovery of neighboring points is the ball tree. A ball tree is very efficient, especially when the number of dimensions is very large. Like a k-d tree, it is a binary tree with a hierarchical structure. To start with, two clusters (each resembling a ball) are created. As the space is multidimensional, each ball may more appropriately be called a hypersphere. Any point in the n-dimensional space belongs to one cluster or the other, but not to both: it belongs to the cluster whose centroid is closer. If the point is equidistant from the centroids of both balls, it may be included in either cluster. The two (virtual) hyperspheres may intersect, but each point still belongs to only one of them. Next, each ball is subdivided into two sub-clusters, each again resembling a ball. Each sub-cluster has its own centroid, and a point’s membership is decided by its distance from the centroids of the sub-clusters. We subdivide each of these sub-balls in the same way, and so on up to a certain depth.
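A nested-ball structure of this kind can be sketched as follows. The text does not say how the two clusters are seeded, so this example makes an assumption: it picks the point farthest from the centroid as one seed and the point farthest from that seed as the other, then assigns every point to the nearer seed (ties go to the first cluster). The names `build_ball`, `leaf_size` and `walk` are made up for the illustration, and the data is the same eleven points used in the k-d tree example.

```python
import math

POINTS = [(22, 38, 21), (4, 8, 6), (2, 14, 3), (8, 20, 12), (10, 26, 18),
          (12, 32, 15), (18, 56, 33), (16, 44, 27), (20, 50, 24),
          (14, 62, 30), (6, 2, 9)]

def centroid(points):
    """Coordinate-wise mean of a non-empty list of points."""
    return tuple(sum(col) / len(points) for col in zip(*points))

def build_ball(points, leaf_size=2):
    """Return a ball (center, radius, points) with zero or two child balls."""
    center = centroid(points)
    node = {
        "center": center,
        "radius": max(math.dist(center, p) for p in points),  # ball encloses all its points
        "points": points,
        "children": [],
    }
    if len(points) <= leaf_size:
        return node
    # Assumed seeding rule: two points that are far apart.
    seed_a = max(points, key=lambda p: math.dist(center, p))
    seed_b = max(points, key=lambda p: math.dist(seed_a, p))
    near_a = [p for p in points if math.dist(p, seed_a) <= math.dist(p, seed_b)]
    near_b = [p for p in points if math.dist(p, seed_a) > math.dist(p, seed_b)]
    if near_b:  # empty only when all points coincide
        node["children"] = [build_ball(near_a, leaf_size),
                            build_ball(near_b, leaf_size)]
    return node

def walk(node):
    """Yield the node and all balls nested inside it."""
    yield node
    for child in node["children"]:
        yield from walk(child)

root = build_ball(POINTS)
for ball in walk(root):
    print(len(ball["points"]), "points, radius", round(ball["radius"], 2))
```

Each point ends up in exactly one ball at every level, even though the balls themselves may overlap in space.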

An unclassified (target) point will fall within, or lie closest to, one of the nested balls. Points within this nested ball are expected to be nearest to the target point. Points in other nearby balls (or enveloping balls) may also be nearer to it; for example, the target point may lie at the boundary of one of the balls. Nevertheless, one need not calculate the distance of this unclassified point from all the points in the n-dimensional space, and this speeds up the classification process. Ball tree formation initially requires a lot of time and memory, but once the nested hyperspheres are created and placed in memory, discovery of the nearest points becomes easier.
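The reason whole balls can be skipped follows from the triangle inequality: no point inside a ball can be closer to the target than the distance to the ball’s center minus its radius. The sketch below shows this standard test, which is not spelled out in the text above; the function names, the ball centers and the radii are hypothetical values chosen for the illustration.

```python
import math

def lower_bound(query, center, radius):
    """Smallest possible distance from `query` to any point inside the ball."""
    return max(0.0, math.dist(query, center) - radius)

def can_skip(query, center, radius, best_so_far):
    """True if no point in the ball can beat the best distance found so far."""
    return lower_bound(query, center, radius) > best_so_far

query = (9, 25, 16)
best_so_far = math.dist(query, (10, 26, 18))  # distance to a neighbor already found

# A far-away ball: every point in it is farther than the current best, so skip it.
print(can_skip(query, (18, 50, 27), 15.0, best_so_far))  # True

# A ball that contains the query: it may hold a closer point, so it must be searched.
print(can_skip(query, (8, 20, 12), 10.0, best_so_far))   # False
```

Only the balls that fail this test need their points examined one by one.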

In my next blog, I give examples of how to use classes in sklearn to perform k-NN classification.


February 6, 2017 at 1:26 pm |

Okay, this was indeed explained in a very simple manner, I’m a biologist and I think I got the concept in a single read through, you’re a good teacher 🙂 But doesn’t the method as you describe it only approximate the point at which it is near? Because it may make sense to take the dimension with the most variance at the root of the tree, but what if based on this dimension you end up at points in n-dimensional space that are really close to your query with regards to this axis with the most variance, but are really far when looking at the other axes, so you actually misclassified your point? I see that it’s better than a random classifier, but how accurate is it really?