softcount: tolerate zero ngrams
[vspell.git] / libvspell / warch.cpp
blobf09b5e57ced0a5a7e33daf180802989904aa935b
1 #include "wordnode.h" // -*- tab-width: 2 coding: viscii mode: c++ -*-
2 #include "syllable.h"
3 #include <utility>
4 #include <fstream>
5 #include <iostream>
6 #include <boost/format.hpp>
8 using namespace std;
9 static strid mainleaf_id,caseleaf_id;
10 std::map<strid,LeafNNode*> LeafNNode::leaf_index;
12 LeafNNode* BranchNNode::get_leaf(strid leaf) const
14 node_map::const_iterator iter;
15 iter = nodes.find(leaf);
16 if (iter != nodes.end())
17 return((LeafNNode*)iter->second.get());
18 else
19 return NULL;
22 void BranchNNode::get_leaves(std::vector<LeafNNode*> &_nodes) const
24 const vector<strid> leaf_id = warch.get_leaf_id();
25 node_map::const_iterator iter;
26 uint i,n = leaf_id.size();
27 for (i = 0;i < n;i ++) {
28 iter = nodes.find(leaf_id[i]);
29 if (iter != nodes.end())
30 _nodes.push_back((LeafNNode*)iter->second.get());
34 void BranchNNode::get_branches(strid _id,std::vector<BranchNNode*> &_nodes) const
36 const_np_range range;
37 range = nodes.equal_range(_id);
38 node_map::const_iterator iter;
39 for (iter = range.first;iter != range.second; ++iter)
40 if (!iter->second->is_leaf())
41 _nodes.push_back((BranchNNode*)iter->second.get());
44 BranchNNode* BranchNNode::get_branch(strid _id) const
46 const_np_range range;
47 range = nodes.equal_range(_id);
48 node_map::const_iterator iter;
49 for (iter = range.first;iter != range.second; ++iter)
50 if (!iter->second->is_leaf())
51 return (BranchNNode*)iter->second.get();
52 return NULL;
55 void BranchNNode::add(strid _id,NNodeRef _branch)
57 nodes.insert(make_pair(_id,_branch));
60 BranchNNode* BranchNNode::add_path(const std::vector<strid> &toks)
62 uint i,n = toks.size();
63 BranchNNode *me = this;
64 for (i = 0;i < n;i ++) {
65 BranchNNode *next = me->get_branch(toks[i]);
66 if (next == NULL) {
67 NNodeRef branch(new BranchNNode());
68 me->add(toks[i],branch);
69 next = (BranchNNode*)branch.get();
71 me = next;
73 return me;
76 void WordArchive::init()
78 mainleaf_id = sarch["<mainleaf>"];
79 caseleaf_id = sarch["<caseleaf>"];
80 register_leaf(mainleaf_id);
81 register_leaf(caseleaf_id);
84 bool WordArchive::load(const char* filename)
86 if (filename != NULL) {
87 ifstream ifs(filename);
89 if (!ifs.is_open())
90 return false;
92 string word;
93 while (ifs >> word) {
94 add_entry(word.c_str());
95 add_case_entry(word.c_str());
98 else {
99 const lm_t * lm = get_ngram().get_lm();
100 for (int i = 0;i < lm->ucount;i ++) {
101 add_entry(lm->word_str[i]);
102 add_case_entry(lm->word_str[i]);
105 return true;
108 LeafNNode* WordArchive::add_special_entry(strid tok)
110 return add_entry(sarch[tok]);
112 LeafNNode *leaf = new LeafNNode;
113 NNodeRef noderef(leaf);
114 vector<strid> toks;
115 toks.push_back(tok);
116 leaf->set_id(toks);
117 //leaf->set_mask(MAIN_LEAF);
118 get_root()->add(tok,noderef);
119 return leaf;
123 LeafNNode* WordArchive::add_entry(const char *w)
125 unsigned len,wlen;
126 const char *pos,*start;
127 char *buf;
128 len = strlen(w);
129 start = pos = w;
130 buf = (char *)malloc(len+1);
131 vector<VocabIndex> syllables;
132 while (pos) {
133 pos = strchr(start,'_');
134 wlen = pos ? pos - start : len - (start - w);
135 memcpy(buf,start,wlen);
136 buf[wlen] = '\0';
137 VocabIndex id = sarch[buf];
138 syllables.push_back(id);
139 start = pos+1;
141 free(buf);
143 vector<strid> path = syllables;
144 BranchNNode* branch = get_root()->add_path(path);
145 assert (!branch->get_leaf(mainleaf_id));
146 NNodeRef noderef(new LeafNNode);
147 LeafNNode *leaf = (LeafNNode*)noderef.get();
148 //leaf->set_mask(MAIN_LEAF);
149 branch->add(mainleaf_id,noderef);
150 leaf->set_id(syllables);
151 return leaf;
154 LeafNNode* WordArchive::add_case_entry(const char *w2)
156 unsigned i,same,len,wlen;
157 const char *pos,*start;
158 char *buf;
159 char *w;
160 len = strlen(w2);
161 w = (char *)malloc(len+1);
162 same = 1;
163 for (i = 0;i < len;i ++) {
164 w[i] = (char)viet_tolower(w2[i]);
165 if (same && w[i] != w2[i])
166 same = 0;
168 if (same) {
169 free(w);
170 return NULL;
172 w[len] = '\0';
173 buf = (char *)malloc(len+1);
174 vector<VocabIndex> syllables,real_syllables;
175 start = pos = w;
176 while (pos) {
177 pos = strchr(start,'_');
178 wlen = pos ? pos - start : len - (start - w);
179 memcpy(buf,start,wlen);
180 buf[wlen] = '\0';
181 VocabIndex id = sarch[buf];
182 syllables.push_back(id);
183 start = pos+1;
185 free(w);
187 start = pos = w2;
188 while (pos) {
189 pos = strchr(start,'_');
190 wlen = pos ? pos - start : len - (start - w2);
191 memcpy(buf,start,wlen);
192 buf[wlen] = '\0';
193 VocabIndex id = sarch[buf];
194 real_syllables.push_back(id);
195 start = pos+1;
197 free(buf);
199 vector<strid> path = syllables;
200 BranchNNode* branch = get_root()->add_path(path);
201 assert(!branch->get_leaf(caseleaf_id));
202 NNodeRef noderef(new LeafNNode);
203 LeafNNode *leaf = (LeafNNode*)noderef.get();
204 //leaf->set_mask(CASE_LEAF);
205 branch->add(caseleaf_id,noderef);
206 leaf->set_id(real_syllables);
207 return leaf;
210 void WordArchive::register_leaf(strid id)
212 if (find(leaf_id.begin(),leaf_id.end(),id) == leaf_id.end())
213 leaf_id.push_back(id);
216 void LeafNNode::set_mask(uint maskval,bool mask)
218 if (mask)
219 bitmask |= maskval;
220 else
221 bitmask &= ~maskval;
224 void LeafNNode::set_id(const vector<strid> &_syllables)
226 syllables = _syllables;
227 string word;
228 int i,nr_syllables = syllables.size();
229 for (i = 0;i < nr_syllables;i ++) {
230 if (i)
231 word += "_";
232 word += sarch[syllables[i]];
234 id = sarch[word];
235 leaf_index[id] = this;
238 LeafNNode* LeafNNode::find_leaf(const vector<strid> &syllables)
240 string word;
241 int i,nr_syllables = syllables.size();
242 for (i = 0;i < nr_syllables;i ++) {
243 if (i)
244 word += "_";
245 word += sarch[syllables[i]];
247 strid id = sarch[word];
248 map<strid,LeafNNode*>::iterator iter = leaf_index.find(id);
249 return iter != leaf_index.end() ? iter->second : NULL;
252 std::ostream& operator << (std::ostream &os,const LeafNNode &node)
254 std::vector<strid> syll;
255 node.get_syllables(syll);
256 os << boost::format("%04x %d") % node.bitmask % syll.size();
257 for (std::vector<strid>::size_type i = 0;i < syll.size();i ++) {
258 os << " ";
259 os << sarch[syll[i]];
261 return os;
264 std::istream& operator >> (std::istream &is,LeafNNode* &node)
266 std::vector<strid> syll;
267 int n;
268 uint bitmask;
269 is >> hex >> bitmask >> dec >> n;
270 syll.resize(n);
271 for (std::vector<strid>::size_type i = 0;i < syll.size();i ++) {
272 string s;
273 is >> s;
274 syll[i] = get_ngram()[s];
276 node = LeafNNode::find_leaf(syll);
277 assert(node);
278 return is;