2 * @brief KMeans clustering API
4 /* Copyright (C) 2016 Richhiey Thomas
6 * This program is free software; you can redistribute it and/or
7 * modify it under the terms of the GNU General Public License as
8 * published by the Free Software Foundation; either version 2 of the
9 * License, or (at your option) any later version.
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301
24 #include "xapian/cluster.h"
25 #include "xapian/error.h"
/** Threshold value for checking convergence in KMeans.
 *
 *  After each iteration the value computed between a centroid's previous
 *  and current position is compared against this; any value above it
 *  means the centroids have not converged yet.
 */
#define CONVERGENCE_THRESHOLD 0.0000000001

/** Maximum number of times the KMeans algorithm will iterate.
 *
 *  Used as the iteration limit when the caller passes 0 for max_iters_.
 */
#define MAX_ITERS 1000
40 using namespace Xapian
;
// Construct a KMeans clusterer.
//
// k_         - requested number of clusters (per the error message below).
// max_iters_ - iteration limit; passing 0 selects the MAX_ITERS default.
//
// Throws InvalidArgumentError; see the note on the throw statement.
43 KMeans::KMeans(unsigned int k_
, unsigned int max_iters_
)
46 LOGCALL_CTOR(API
, "KMeans", k_
| max_iters_
);
// A max_iters_ of 0 is treated as "use the built-in limit", MAX_ITERS.
47 max_iters
= (max_iters_
== 0) ? MAX_ITERS
: max_iters_
;
// NOTE(review): the condition guarding this throw and the remainder of the
// message literal are not visible in this excerpt - confirm against the
// full file before relying on the exact validation rule.
49 throw InvalidArgumentError("Number of required clusters should be "
// Const member function; its return type and body are not visible in this
// excerpt. NOTE(review): presumably returns a human-readable description of
// this object - confirm against the full file.
54 KMeans::get_description() const
60 KMeans::initialise_clusters(ClusterSet
& cset
, doccount num_of_points
)
62 LOGCALL_VOID(API
, "KMeans::initialise_clusters", cset
| num_of_points
);
63 // Initial centroids are selected by picking points at roughly even
64 // intervals within the MSet. This is cheap and helps pick diverse
65 // elements since the MSet is usually sorted by some sort of key
66 for (unsigned int i
= 0; i
< k
; ++i
) {
67 unsigned int x
= (i
* num_of_points
) / k
;
68 cset
.add_cluster(Cluster(Centroid(points
[x
])));
73 KMeans::initialise_points(const MSet
& source
)
75 LOGCALL_VOID(API
, "KMeans::initialise_points", source
);
76 TermListGroup
tlg(source
, stopper
.get());
77 for (MSetIterator it
= source
.begin(); it
!= source
.end(); ++it
)
78 points
.push_back(Point(tlg
, it
.get_document()));
82 KMeans::cluster(const MSet
& mset
)
84 LOGCALL(API
, ClusterSet
, "KMeans::cluster", mset
);
85 doccount size
= mset
.size();
88 initialise_points(mset
);
90 initialise_clusters(cset
, size
);
91 CosineDistance distance
;
92 vector
<Centroid
> previous_centroids
;
93 for (unsigned int i
= 0; i
< max_iters
; ++i
) {
94 // Assign each point to the cluster corresponding to its
95 // closest cluster centroid
96 cset
.clear_clusters();
97 for (unsigned int j
= 0; j
< size
; ++j
) {
98 double closest_cluster_distance
= numeric_limits
<double>::max();
99 unsigned int closest_cluster
= 0;
100 for (unsigned int c
= 0; c
< k
; ++c
) {
101 const Centroid
& centroid
= cset
[c
].get_centroid();
102 double dist
= distance
.similarity(points
[j
], centroid
);
103 if (closest_cluster_distance
> dist
) {
104 closest_cluster_distance
= dist
;
108 cset
.add_to_cluster(points
[j
], closest_cluster
);
111 // Remember the previous centroids
112 previous_centroids
.clear();
113 for (unsigned int j
= 0; j
< k
; ++j
)
114 previous_centroids
.push_back(cset
[j
].get_centroid());
116 // Recalculate the centroids for current iteration
117 cset
.recalculate_centroids();
119 // Check whether centroids have converged
120 bool has_converged
= true;
121 for (unsigned int j
= 0; j
< k
; ++j
) {
122 const Centroid
& centroid
= cset
[j
].get_centroid();
123 double dist
= distance
.similarity(previous_centroids
[j
], centroid
);
124 // If distance between any two centroids has changed by
125 // more than the threshold, then KMeans hasn't converged
126 if (dist
> CONVERGENCE_THRESHOLD
) {
127 has_converged
= false;
131 // If converged, then break from the loop