Accept '/r/foo' everywhere: part 2
[reddit.git] / scripts / inject_test_data.py
blobc57fc6860921076bf3a51a3fa1f6b58b54ae6e76
1 # The contents of this file are subject to the Common Public Attribution
2 # License Version 1.0. (the "License"); you may not use this file except in
3 # compliance with the License. You may obtain a copy of the License at
4 # http://code.reddit.com/LICENSE. The License is based on the Mozilla Public
5 # License Version 1.1, but Sections 14 and 15 have been added to cover use of
6 # software over a computer network and provide for limited attribution for the
7 # Original Developer. In addition, Exhibit A has been modified to be consistent
8 # with Exhibit B.
10 # Software distributed under the License is distributed on an "AS IS" basis,
11 # WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for
12 # the specific language governing rights and limitations under the License.
14 # The Original Code is reddit.
16 # The Original Developer is the Initial Developer. The Initial Developer of
17 # the Original Code is reddit Inc.
19 # All portions of the code written by reddit are Copyright (c) 2006-2015 reddit
20 # Inc. All Rights Reserved.
21 ###############################################################################
23 from __future__ import division
25 import collections
26 import HTMLParser
27 import itertools
28 import random
29 import string
30 import time
32 import requests
34 from pylons import g
36 from r2.lib.db import queries
37 from r2.lib import amqp
38 from r2.lib.utils import weighted_lottery
39 from r2.models import Account, NotFound, register, Subreddit, Link, Comment
# One shared parser instance; its bound unescape() method converts HTML
# entities (e.g. "&amp;") back into their literal characters.
_html_parser = HTMLParser.HTMLParser()
unescape_htmlentities = _html_parser.unescape
class TextGenerator(object):
    """A Markov-chain based text mimicker.

    Feed sample strings in with add_sample(); generate() then produces
    new strings whose character statistics resemble the samples.
    """

    def __init__(self, order=8):
        # Highest prefix length tracked by the chain.
        self.order = order
        # How often each `order`-character prefix started a sample.
        self.starts = collections.Counter()
        # Distribution of sample lengths observed per start prefix.
        self.start_lengths = collections.defaultdict(collections.Counter)
        # models[k] maps a (k+1)-character prefix to a Counter of the
        # characters seen immediately after that prefix.
        self.models = [collections.defaultdict(collections.Counter)
                       for _ in range(self.order)]

    @staticmethod
    def _in_groups(input_iterable, n):
        """Return successive overlapping n-tuples from the iterable."""
        shifted = itertools.tee(input_iterable, n)
        # Advance the k-th copy by k items so zipping yields windows.
        for skip, it in enumerate(shifted):
            for _ in range(skip):
                next(it, None)
        return zip(*shifted)

    def add_sample(self, sample):
        """Add a sample to the model of text for this generator."""
        # Samples no longer than the order carry no transition data.
        if len(sample) <= self.order:
            return

        head = sample[:self.order]
        self.starts[head] += 1
        self.start_lengths[head][len(sample)] += 1

        for length, model in enumerate(self.models, 1):
            for window in self._in_groups(sample, length + 1):
                model["".join(window[:-1])][window[-1]] += 1

    def generate(self):
        """Generate a string similar to samples previously fed in."""
        seed = weighted_lottery(self.starts)
        target_length = max(
            weighted_lottery(self.start_lengths[seed]), self.order)

        out = list(seed)
        while len(out) < target_length:
            # Try each model from highest order down until one has seen
            # the current suffix.
            for length, model in reversed(list(enumerate(self.models, 1))):
                followers = model["".join(out[-length:])]
                if followers:
                    out.append(weighted_lottery(followers))
                    break
            else:
                # No model matched; fall back to a random letter.
                out.append(random.choice(string.lowercase))

        return "".join(out)
def fetch_listing(path, limit=1000, batch_size=100):
    """Yield up to `limit` items from a reddit.com listing endpoint.

    Pages through the public JSON API in `batch_size` chunks, yielding
    each child's data dict, and sleeps between requests to stay within
    the API's ratelimit.
    """
    http = requests.Session()
    http.headers.update({
        "User-Agent": "reddit-test-data-generator/1.0",
    })

    url = "https://api.reddit.com" + path

    fullname_after = None
    seen = 0
    while seen < limit:
        query = {"limit": batch_size, "count": seen}
        if fullname_after:
            query["after"] = fullname_after

        print("> {}-{}".format(seen, seen + batch_size))
        response = http.get(url, params=query)
        response.raise_for_status()

        # NOTE(review): in this vintage of requests, .json is a property,
        # not a method.
        listing = response.json["data"]
        for child in listing["children"]:
            yield child["data"]
            seen += 1

        fullname_after = listing["after"]
        if not fullname_after:
            break

        # obey reddit.com's ratelimits
        # see: https://github.com/reddit/reddit/wiki/API#rules
        time.sleep(2)
138 class Modeler(object):
139 def __init__(self):
140 self.usernames = TextGenerator(order=2)
142 def model_subreddit(self, subreddit_name):
143 """Return a model of links and comments in a given subreddit."""
145 subreddit_path = "/r/{}".format(subreddit_name)
146 print ">>>", subreddit_path
148 print ">> Links"
149 titles = TextGenerator(order=5)
150 selfposts = TextGenerator(order=8)
151 link_count = self_count = 0
152 urls = set()
153 for link in fetch_listing(subreddit_path, limit=500):
154 self.usernames.add_sample(link["author"])
155 titles.add_sample(unescape_htmlentities(link["title"]))
156 if link["is_self"]:
157 self_count += 1
158 selfposts.add_sample(unescape_htmlentities(link["selftext"]))
159 else:
160 urls.add(link["url"])
161 link_count += 1
162 self_frequency = self_count / link_count
164 print ">> Comments"
165 comments = TextGenerator(order=8)
166 for comment in fetch_listing(subreddit_path + "/comments"):
167 self.usernames.add_sample(comment["author"])
168 comments.add_sample(unescape_htmlentities(comment["body"]))
170 return SubredditModel(
171 subreddit_name, titles, selfposts, urls, comments, self_frequency)
173 def generate_username(self):
174 """Generate and return a username like those seen on reddit.com."""
175 return self.usernames.generate()
class SubredditModel(object):
    """A snapshot of a subreddit's links and comments."""

    def __init__(self, name, titles, selfposts, urls, comments, self_frequency):
        self.name = name
        # Markov models for the various text fields.
        self.titles = titles
        self.selfposts = selfposts
        self.comments = comments
        # Stored as a list so random.choice can sample it.
        self.urls = list(urls)
        # Fraction of modeled links that were self posts.
        self.selfpost_frequency = self_frequency

    def generate_link_title(self):
        """Generate and return a title like those seen in the subreddit."""
        return self.titles.generate()

    def generate_link_url(self):
        """Generate and return a URL from one seen in the subreddit.

        The URL returned may be "self" indicating a self post. This should
        happen with the same frequency it is seen in the modeled subreddit.
        """
        is_selfpost = random.random() < self.selfpost_frequency
        return "self" if is_selfpost else random.choice(self.urls)

    def generate_selfpost_body(self):
        """Generate and return a self-post body like seen in the subreddit."""
        return self.selfposts.generate()

    def generate_comment_body(self):
        """Generate and return a comment body like seen in the subreddit."""
        return self.comments.generate()
def fuzz_number(number):
    """Return a randomized variation of `number`.

    betavariate(2, 8) has mean 0.2, so after scaling by 5 results average
    around `number` but can land anywhere in [0, 5 * number).
    """
    factor = random.betavariate(2, 8) * 5
    return int(factor * number)
def ensure_account(name):
    """Look up or register an account and return it."""
    try:
        account = Account._by_name(name)
    except NotFound:
        # No such user yet; register one with a dummy password/IP.
        print(">> registering /u/{}".format(name))
        return register(name, "password", "127.0.0.1")
    else:
        print(">> found /u/{}".format(name))
        return account
def ensure_subreddit(name, author):
    """Look up or create a subreddit and return it."""
    try:
        sr = Subreddit._by_name(name)
    except NotFound:
        print(">> creating /r/{}".format(name))
        sr = Subreddit._new(
            name=name,
            title="/r/{}".format(name),
            author_id=author._id,
            lang="en",
            ip="127.0.0.1",
        )
        sr._commit()
    else:
        print(">> found /r/{}".format(name))
    return sr
def inject_test_data(num_links=25, num_comments=25, num_votes=5):
    """Flood your reddit install with test data based on reddit.com."""

    print(">>>> Ensuring configured objects exist")
    system_user = ensure_account(g.system_user)
    ensure_account(g.automoderator_account)
    ensure_subreddit(g.default_sr, system_user)
    ensure_subreddit(g.takedown_sr, system_user)

    print
    print

    print(">>>> Fetching real data from reddit.com")
    modeler = Modeler()
    subreddits = [
        modeler.model_subreddit("pics"),
        modeler.model_subreddit("videos"),
        modeler.model_subreddit("askhistorians"),
    ]
    extra_settings = {
        "pics": {
            "show_media": True,
        },
        "videos": {
            "show_media": True,
        },
    }

    print
    print

    print(">>>> Generating test data")
    print(">>> Accounts")
    # Reuse any existing accounts (except the system user), then pad out
    # to 50 with freshly generated usernames.
    account_query = Account._query(sort="_date", limit=500, data=True)
    accounts = [a for a in account_query if a.name != g.system_user]
    accounts.extend(
        ensure_account(modeler.generate_username())
        for _ in xrange(50 - len(accounts)))

    print(">>> Content")
    things = []
    for sr_model in subreddits:
        sr_author = random.choice(accounts)
        sr = ensure_subreddit(sr_model.name, sr_author)

        # make the system user subscribed for easier testing
        if sr.add_subscriber(system_user):
            sr._incr("_ups", 1)

        # apply any custom config we need for this sr
        for setting, value in extra_settings.get(sr.name, {}).iteritems():
            setattr(sr, setting, value)
        sr._commit()

        for _ in xrange(num_links):
            link_author = random.choice(accounts)
            link = Link._submit(
                title=sr_model.generate_link_title(),
                url=sr_model.generate_link_url(),
                author=link_author,
                sr=sr,
                ip="127.0.0.1",
            )
            # A "self" URL marks a self post: point it at its own
            # permalink and give it a generated body instead.
            if link.url == "self":
                link.url = link.make_permalink(sr)
                link.is_self = True
                link.selftext = sr_model.generate_selfpost_body()
                link._commit()
            queries.queue_vote(link_author, link, dir=True, ip="127.0.0.1")
            queries.new_link(link)
            things.append(link)

            # Build a comment tree under the link; None = top level.
            parents = [None]
            for _ in xrange(fuzz_number(num_comments)):
                commenter = random.choice(accounts)
                comment, inbox_rel = Comment._new(
                    commenter,
                    link,
                    parent=random.choice(parents),
                    body=sr_model.generate_comment_body(),
                    ip="127.0.0.1",
                )
                queries.queue_vote(
                    commenter, comment, dir=True, ip="127.0.0.1")
                queries.new_comment(comment, inbox_rel)
                parents.append(comment)
                things.append(comment)

    # Scatter random up/none/down votes across everything we created.
    for thing in things:
        for _ in xrange(fuzz_number(num_votes)):
            direction = random.choice([True, None, False])
            voter = random.choice(accounts)
            queries.queue_vote(voter, thing, dir=direction, ip="127.0.0.1")

    amqp.worker.join()