Inline simple non-virtual cluster API methods
[xapian.git] / xapian-core / cluster / kmeans.cc
blob6af2a13e21a5aaa7b466dac4cd14d96552e24935
/** @file
 * @brief KMeans clustering API
 */
/* Copyright (C) 2016 Richhiey Thomas
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
 * USA
 */
22 #include <config.h>
24 #include "xapian/cluster.h"
25 #include "xapian/error.h"
27 #include "debuglog.h"
29 #include <limits>
30 #include <vector>
/** Threshold for checking convergence in KMeans.
 *
 *  Iteration stops once no centroid has moved (as measured by the
 *  clusterer's distance function) by more than this amount.
 */
static constexpr double CONVERGENCE_THRESHOLD = 0.0000000001;

/** Maximum number of times the KMeans algorithm will iterate
 *  while waiting for it to converge.
 *
 *  Used when the constructor is given a max_iters_ of zero.
 */
static constexpr unsigned int MAX_ITERS = 1000;
40 using namespace Xapian;
41 using namespace std;
43 KMeans::KMeans(unsigned int k_, unsigned int max_iters_)
44 : k(k_)
46 LOGCALL_CTOR(API, "KMeans", k_ | max_iters_);
47 max_iters = (max_iters_ == 0) ? MAX_ITERS : max_iters_;
48 if (k_ == 0)
49 throw InvalidArgumentError("Number of required clusters should be "
50 "greater than zero");
53 string
54 KMeans::get_description() const
56 return "KMeans()";
59 void
60 KMeans::initialise_clusters(ClusterSet& cset, doccount num_of_points)
62 LOGCALL_VOID(API, "KMeans::initialise_clusters", cset | num_of_points);
63 // Initial centroids are selected by picking points at roughly even
64 // intervals within the MSet. This is cheap and helps pick diverse
65 // elements since the MSet is usually sorted by some sort of key
66 for (unsigned int i = 0; i < k; ++i) {
67 unsigned int x = (i * num_of_points) / k;
68 cset.add_cluster(Cluster(Centroid(points[x])));
72 void
73 KMeans::initialise_points(const MSet& source)
75 LOGCALL_VOID(API, "KMeans::initialise_points", source);
76 TermListGroup tlg(source, stopper.get());
77 for (MSetIterator it = source.begin(); it != source.end(); ++it)
78 points.push_back(Point(tlg, it.get_document()));
81 ClusterSet
82 KMeans::cluster(const MSet& mset)
84 LOGCALL(API, ClusterSet, "KMeans::cluster", mset);
85 doccount size = mset.size();
86 if (k >= size)
87 k = size;
88 initialise_points(mset);
89 ClusterSet cset;
90 initialise_clusters(cset, size);
91 CosineDistance distance;
92 vector<Centroid> previous_centroids;
93 for (unsigned int i = 0; i < max_iters; ++i) {
94 // Assign each point to the cluster corresponding to its
95 // closest cluster centroid
96 cset.clear_clusters();
97 for (unsigned int j = 0; j < size; ++j) {
98 double closest_cluster_distance = numeric_limits<double>::max();
99 unsigned int closest_cluster = 0;
100 for (unsigned int c = 0; c < k; ++c) {
101 const Centroid& centroid = cset[c].get_centroid();
102 double dist = distance.similarity(points[j], centroid);
103 if (closest_cluster_distance > dist) {
104 closest_cluster_distance = dist;
105 closest_cluster = c;
108 cset.add_to_cluster(points[j], closest_cluster);
111 // Remember the previous centroids
112 previous_centroids.clear();
113 for (unsigned int j = 0; j < k; ++j)
114 previous_centroids.push_back(cset[j].get_centroid());
116 // Recalculate the centroids for current iteration
117 cset.recalculate_centroids();
119 // Check whether centroids have converged
120 bool has_converged = true;
121 for (unsigned int j = 0; j < k; ++j) {
122 const Centroid& centroid = cset[j].get_centroid();
123 double dist = distance.similarity(previous_centroids[j], centroid);
124 // If distance between any two centroids has changed by
125 // more than the threshold, then KMeans hasn't converged
126 if (dist > CONVERGENCE_THRESHOLD) {
127 has_converged = false;
128 break;
131 // If converged, then break from the loop
132 if (has_converged)
133 break;
135 return cset;