Merge branch 'release-0.11.0'
[tor-bridgedb.git] / bridgedb / Storage.py
blob4ae3c9e6e0fdc7f64657b778b0b42ebf51849d71
1 # BridgeDB by Nick Mathewson.
2 # Copyright (c) 2007-2009, The Tor Project, Inc.
3 # See LICENSE for licensing information
5 import calendar
6 import logging
7 import binascii
8 import sqlite3
9 import time
10 import hashlib
11 from functools import wraps
12 from ipaddr import IPAddress
13 from contextlib import contextmanager
14 import sys
15 import datetime
17 from bridgedb.Stability import BridgeHistory
18 import threading
# Short aliases for binascii's hex encode/decode helpers.
toHex, fromHex = binascii.b2a_hex, binascii.a2b_hex

# Length of a hex-encoded bridge fingerprint (checked when bridges are
# inserted into the Bridges table).
HEX_ID_LEN = 40

# Verdict codes stored in BridgeMeasurements.verdict.
BRIDGE_REACHABLE = 0
BRIDGE_BLOCKED = 1
25 def _escapeValue(v):
26 return "'%s'" % v.replace("'", "''")
def timeToStr(t):
    """Format *t* (seconds since the epoch) as a UTC "YYYY-MM-DD HH:MM" string."""
    utc = time.gmtime(t)
    return time.strftime("%Y-%m-%d %H:%M", utc)
def strToTime(t):
    """Parse a UTC "YYYY-MM-DD HH:MM" string into seconds since the epoch.

    This is the inverse of timeToStr(), at minute precision.
    """
    parsed = time.strptime(t, "%Y-%m-%d %H:%M")
    return calendar.timegm(parsed)
33 # The old DB system was just a key->value mapping DB, with special key
34 # prefixes to indicate which database they fell into.
36 # sp|<ID> -- given to bridgesplitter; maps bridgeID to ring name.
37 # em|<emailaddr> -- given to emailbaseddistributor; maps email address
38 # to concatenated ID.
39 # fs|<ID> -- Given to BridgeTracker, maps to time when a router was
40 # first seen (YYYY-MM-DD HH:MM)
41 # ls|<ID> -- given to bridgetracker, maps to time when a router was
42 # last seen (YYYY-MM-DD HH:MM)
44 # We no longer want to use em| at all, since we're not doing that kind
45 # of persistence any more.
47 # Here is the SQL schema.
# Schema version 2, the layout of databases created by older releases.
# The final statement records the version in the Config table, where
# openDatabase() reads it back to decide whether an upgrade is needed.
# Changing any table or index here would require a migration for databases
# that already exist.
# NOTE(review): the index named BlockedBridgesBlockingCountry is built on
# hex_key, not on blocking_country, despite its name.
# NOTE(review): email is already the PRIMARY KEY of EmailedBridges and
# WarnedEmails, so the extra indexes on that column are likely redundant.
SCHEMA2_SCRIPT = """
 CREATE TABLE Config (
     key PRIMARY KEY NOT NULL,
     value
 );

 CREATE TABLE Bridges (
     id INTEGER PRIMARY KEY NOT NULL,
     hex_key,
     address,
     or_port,
     distributor,
     first_seen,
     last_seen
 );

 CREATE UNIQUE INDEX BridgesKeyIndex ON Bridges ( hex_key );

 CREATE TABLE EmailedBridges (
     email PRIMARY KEY NOT NULL,
     when_mailed
 );

 CREATE INDEX EmailedBridgesWhenMailed on EmailedBridges ( email );

 CREATE TABLE BridgeMeasurements (
     id INTEGER PRIMARY KEY NOT NULL,
     hex_key,
     bridge_type,
     address,
     port,
     blocking_country,
     blocking_asn,
     measured_by,
     last_measured,
     verdict INTEGER
 );

 CREATE INDEX BlockedBridgesBlockingCountry on BridgeMeasurements(hex_key);

 CREATE TABLE WarnedEmails (
     email PRIMARY KEY NOT NULL,
     when_warned
 );

 CREATE INDEX WarnedEmailsWasWarned on WarnedEmails ( email );

 INSERT INTO Config VALUES ( 'schema-version', 2 );
"""

# Upgrade from schema version 2 to 3: adds the BridgeHistory table, whose
# rows the Database class maps to and from BridgeHistory objects, and bumps
# the stored version.  INSERT OR REPLACE is needed because Config already
# holds a 'schema-version' row at this point.
# NOTE(review): fingerprint is already the PRIMARY KEY, so
# BridgeHistoryIndex is likely redundant.
SCHEMA_2TO3_SCRIPT = """
 CREATE TABLE BridgeHistory (
     fingerprint PRIMARY KEY NOT NULL,
     address,
     port INT,
     weightedUptime LONG,
     weightedTime LONG,
     weightedRunLength LONG,
     totalRunWeights DOUBLE,
     lastSeenWithDifferentAddressAndPort LONG,
     lastSeenWithThisAddressAndPort LONG,
     lastDiscountedHistoryValues LONG,
     lastUpdatedWeightedTime LONG
 );

 CREATE INDEX BridgeHistoryIndex on BridgeHistory ( fingerprint );

 INSERT OR REPLACE INTO Config VALUES ( 'schema-version', 3 );
"""

# Full script for a brand-new database: build version 2, then apply the
# 2 -> 3 upgrade, leaving the stored version at 3.
SCHEMA3_SCRIPT = SCHEMA2_SCRIPT + SCHEMA_2TO3_SCRIPT
class BridgeData(object):
    """Plain value object describing one row of the Bridges table.

    Attributes:
      hex_key     -- unique hex key of the bridge
      address     -- bridge IP address
      or_port     -- bridge TCP port
      distributor -- distributor (or pseudo-distributor) through which the
                     bridge is announced; "unallocated" unless given
      first_seen  -- when the bridge was first seen online
      last_seen   -- when the bridge was last seen online
    """

    def __init__(self, hex_key, address, or_port, distributor="unallocated",
                 first_seen="", last_seen=""):
        self.hex_key, self.address, self.or_port = hex_key, address, or_port
        self.distributor = distributor
        self.first_seen, self.last_seen = first_seen, last_seen
class Database(object):
    """Wrapper around one SQLite connection holding BridgeDB's state.

    Every query goes through the single shared cursor ``self._cur``.  None
    of the methods below commit on their own; call :meth:`commit` or
    :meth:`rollback` when done.

    Timestamps are stored as UTC "YYYY-MM-DD HH:MM" strings (timeToStr() /
    strToTime()); email addresses are stored only as SHA-1 hex digests.
    """

    def __init__(self, sqlite_fname):
        """Open (creating or upgrading if needed) the database file
        ``sqlite_fname`` via openDatabase().
        """
        self._conn = openDatabase(sqlite_fname)
        self._cur = self._conn.cursor()
        self.sqlite_fname = sqlite_fname

    def commit(self):
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self):
        """Close the cursor and then the connection (without committing)."""
        self._cur.close()
        self._conn.close()

    @staticmethod
    def _hashEmail(addr):
        """Return the SHA-1 hex digest stored in place of the address *addr*."""
        return hashlib.sha1(addr.encode('utf-8')).hexdigest()

    @staticmethod
    def _rowToBridgeHistory(row):
        """Build a BridgeHistory from a ``SELECT * FROM BridgeHistory`` row."""
        return BridgeHistory(row[0], IPAddress(row[1]), *row[2:11])

    def getBridgeDistributor(self, bridge, validRings):
        """If a ``bridge`` is already in the database, get its distributor.

        :param bridge: An object with a ``fingerprint`` attribute (hex key).
        :param validRings: Container of distributor names considered valid.
        :rtype: None or str
        :returns: The ``bridge`` distribution method, if one was already
            assigned and it is in ``validRings``; otherwise, None.
        """
        self._cur.execute("SELECT id, distributor FROM Bridges WHERE hex_key = ?",
                          (bridge.fingerprint,))
        row = self._cur.fetchone()
        if row is not None and row[1] in validRings:
            return row[1]
        return None

    def insertBridgeAndGetRing(self, bridge, setRing, seenAt, validRings,
                               defaultPool="unallocated"):
        """Insert or update ``bridge`` and assign it to a distributor.

        The bridge is assigned to ``setRing``, or to ``defaultPool`` if
        ``setRing`` is not in ``validRings``.  A new bridge gets ``seenAt``
        as both its first- and last-seen time; a known bridge has its
        address, port, distributor and last-seen time updated.

        :param bridge: Object with ``fingerprint`` (HEX_ID_LEN hex chars),
            ``address`` and ``orPort`` attributes.
        :param seenAt: Seconds since the epoch (UTC).
        :returns: The name of the distributor the bridge is assigned to.
        """
        t = timeToStr(seenAt)
        h = bridge.fingerprint
        assert len(h) == HEX_ID_LEN

        # Check if this is currently a valid ring name.  If not, move into
        # the default pool.
        if setRing not in validRings:
            setRing = defaultPool

        cur = self._cur
        cur.execute("SELECT id FROM Bridges WHERE hex_key = ?", (h,))
        row = cur.fetchone()
        if row is not None:
            # Known bridge: update last_seen, address, port and distributor.
            cur.execute("UPDATE Bridges SET address = ?, or_port = ?, "
                        "distributor = ?, last_seen = ? WHERE id = ?",
                        (str(bridge.address), bridge.orPort, setRing, t, row[0]))
        else:
            cur.execute("INSERT INTO Bridges (hex_key, address, or_port, "
                        "distributor, first_seen, last_seen) "
                        "VALUES (?, ?, ?, ?, ?, ?)",
                        (h, str(bridge.address), bridge.orPort, setRing, t, t))
        return setRing

    def cleanEmailedBridges(self, expireBefore):
        """Forget every email address mailed before ``expireBefore``.

        :param expireBefore: Seconds since the epoch (UTC).
        """
        self._cur.execute("DELETE FROM EmailedBridges WHERE when_mailed < ?",
                          (timeToStr(expireBefore),))

    def getEmailTime(self, addr):
        """Return when ``addr`` was last mailed, in seconds since the epoch,
        or None if there is no record for it.
        """
        self._cur.execute("SELECT when_mailed FROM EmailedBridges WHERE email = ?",
                          (self._hashEmail(addr),))
        row = self._cur.fetchone()
        if row is None:
            return None
        return strToTime(row[0])

    def setEmailTime(self, addr, whenMailed):
        """Record that ``addr`` was mailed at ``whenMailed`` (seconds since
        the epoch), replacing any previous record.
        """
        self._cur.execute("INSERT OR REPLACE INTO EmailedBridges "
                          "(email,when_mailed) VALUES (?,?)",
                          (self._hashEmail(addr), timeToStr(whenMailed)))

    def getAllBridges(self):
        """Return a list of BridgeData value classes of all bridges in the
        database.
        """
        self._cur.execute("SELECT hex_key, address, or_port, distributor, "
                          "first_seen, last_seen FROM Bridges")
        return [BridgeData(*row) for row in self._cur.fetchall()]

    def getBlockedBridges(self):
        """Return a dictionary of bridges that are blocked.

        :rtype: dict
        :returns: A dictionary that maps bridge fingerprints (as strings) to
            a list of three-tuples that capture its blocking state:
            (country, address, port).
        """
        ms = self.__fetchBridgeMeasurements()
        return getBlockedBridgesFromSql(ms)

    def __fetchBridgeMeasurements(self):
        """Return all bridge measurement rows from the last three years.

        We limit our search to three years for performance reasons because
        the bridge measurement table keeps growing and therefore slowing down
        queries.

        :rtype: list
        :returns: A list of tuples (full BridgeMeasurements rows), ordered by
            blocking country, descending.
        """
        cutoff = datetime.datetime.utcnow() - datetime.timedelta(days=365 * 3)
        # The cutoff is bound as a parameter rather than pasted into the SQL
        # text; it is still compared as a "YYYY-MM-DD" string.
        self._cur.execute("SELECT * FROM BridgeMeasurements WHERE last_measured > ? "
                          "ORDER BY blocking_country DESC",
                          (cutoff.strftime("%Y-%m-%d"),))
        return self._cur.fetchall()

    def getBridgesForDistributor(self, distributor):
        """Return a list of BridgeData value classes of all bridges in the
        database that are allocated to distributor 'distributor'.
        """
        self._cur.execute("SELECT hex_key, address, or_port, distributor, "
                          "first_seen, last_seen FROM Bridges WHERE "
                          "distributor = ?", (distributor,))
        return [BridgeData(*row) for row in self._cur.fetchall()]

    def updateDistributorForHexKey(self, distributor, hex_key):
        """Assign the bridge with key ``hex_key`` to ``distributor``."""
        self._cur.execute("UPDATE Bridges SET distributor = ? WHERE hex_key = ?",
                          (distributor, hex_key))

    def getWarnedEmail(self, addr):
        """Return True if ``addr`` is recorded as warned, else False."""
        self._cur.execute("SELECT * FROM WarnedEmails WHERE email = ?",
                          (self._hashEmail(addr),))
        return self._cur.fetchone() is not None

    def setWarnedEmail(self, addr, warned=True, whenWarned=None):
        """Record (``warned=True``) or forget (``warned=False``) that
        ``addr`` has been warned.  Any other value of ``warned`` is a no-op.

        Recording an address that is already present raises
        ``sqlite3.IntegrityError``, since ``email`` is the table's primary key.

        :param whenWarned: Seconds since the epoch; defaults to the time of
            the call.
        """
        if whenWarned is None:
            # Evaluated per call: a ``time.time()`` default argument would be
            # computed once, at import time, and reused forever.
            whenWarned = time.time()
        addr = self._hashEmail(addr)
        t = timeToStr(whenWarned)
        if warned == True:
            self._cur.execute("INSERT INTO WarnedEmails"
                              "(email,when_warned) VALUES (?,?)", (addr, t,))
        elif warned == False:
            self._cur.execute("DELETE FROM WarnedEmails WHERE email = ?", (addr,))

    def cleanWarnedEmails(self, expireBefore):
        """Forget every warning issued before ``expireBefore`` (seconds since
        the epoch).
        """
        self._cur.execute("DELETE FROM WarnedEmails WHERE when_warned < ?",
                          (timeToStr(expireBefore),))

    def updateIntoBridgeHistory(self, bh):
        """Insert or replace the BridgeHistory row for ``bh``; return ``bh``."""
        self._cur.execute(
            "INSERT OR REPLACE INTO BridgeHistory values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
            (bh.fingerprint, str(bh.ip), bh.port,
             bh.weightedUptime, bh.weightedTime, bh.weightedRunLength,
             bh.totalRunWeights, bh.lastSeenWithDifferentAddressAndPort,
             bh.lastSeenWithThisAddressAndPort, bh.lastDiscountedHistoryValues,
             bh.lastUpdatedWeightedTime))
        return bh

    def delBridgeHistory(self, fp):
        """Delete the BridgeHistory row for fingerprint ``fp``, if any."""
        self._cur.execute("DELETE FROM BridgeHistory WHERE fingerprint = ?", (fp,))

    def getBridgeHistory(self, fp):
        """Return the BridgeHistory for fingerprint ``fp``, or None."""
        self._cur.execute("SELECT * FROM BridgeHistory WHERE fingerprint = ?", (fp,))
        row = self._cur.fetchone()
        if row is None:
            return None
        return self._rowToBridgeHistory(row)

    def getAllBridgeHistory(self):
        """Yield a BridgeHistory for every row of the BridgeHistory table.

        Rows are fetched up front: ``self._cur`` is shared, so any query the
        caller runs on this Database while iterating (for example
        updateIntoBridgeHistory) would otherwise replace the pending result
        set and silently end the iteration early.
        """
        self._cur.execute("SELECT * FROM BridgeHistory")
        for row in self._cur.fetchall():
            yield self._rowToBridgeHistory(row)

    def getBridgesLastUpdatedBefore(self, statusPublicationMillis):
        """Yield a BridgeHistory for every row whose lastUpdatedWeightedTime
        is below ``statusPublicationMillis`` (presumably milliseconds, as the
        name says).

        Rows are fetched up front for the same reason as in
        getAllBridgeHistory(): callers may update rows while iterating.
        """
        self._cur.execute("SELECT * FROM BridgeHistory WHERE lastUpdatedWeightedTime < ?",
                          (statusPublicationMillis,))
        for row in self._cur.fetchall():
            yield self._rowToBridgeHistory(row)
def openDatabase(sqlite_file):
    """Open the BridgeDB SQLite database, creating or upgrading it as needed.

    - No Config table (the SELECT fails with OperationalError): the full
      version-3 schema is created.
    - Stored schema version 2: the 2 -> 3 upgrade script is applied.
    - Version 3: nothing to do.
    - Anything else, including a Config table without a 'schema-version'
      row: a warning is logged and the database is used as is.

    :param sqlite_file: Path of the database file (or ":memory:").
    :returns: The open ``sqlite3.Connection``; the caller must close it.
        On any error the connection is closed before the exception
        propagates.
    """
    conn = sqlite3.Connection(sqlite_file)
    try:
        cur = conn.cursor()
        try:
            try:
                cur.execute("SELECT value FROM Config WHERE key = 'schema-version'")
            except sqlite3.OperationalError:
                logging.warning("No Config table found in DB; creating tables")
                cur.executescript(SCHEMA3_SCRIPT)
            else:
                row = cur.fetchone()
                # A Config table without a version row is treated as an
                # unknown version instead of crashing on tuple unpacking.
                val = row[0] if row is not None else None
                if val == 2:
                    logging.info("Adding new table BridgeHistory")
                    cur.executescript(SCHEMA_2TO3_SCRIPT)
                elif val != 3:
                    logging.warning("Unknown schema version %s in database.", val)
            conn.commit()
        finally:
            cur.close()
    except Exception:
        # Don't leak the connection when setup fails.
        conn.close()
        raise
    return conn
# Module-level state shared by getDB() and its helpers.
_DB_FNAME = None    # SQLite file getDB() opens; set by setDBFilename()
_LOCK = None        # reentrant lock created by initializeDBLock()
_OPENED_DB = None   # Database instance shared by active getDB() users
_LOCKED = 0         # number of getDB() contexts currently holding _LOCK
_REFCOUNT = 0       # number of getDB() users of _OPENED_DB
class BridgeMeasurement(object):
    """One row of the ``BridgeMeasurements`` table.

    The constructor takes the row's columns in table order, so a row tuple
    can be passed as ``BridgeMeasurement(*row)``.  Only the fields used to
    decide whether a bridge is blocked are kept; ``id``, ``bridge_type``,
    ``asn`` and ``measured_by`` are accepted and discarded.
    """

    def __init__(self, id, fingerprint, bridge_type, address, port,
                 country, asn, measured_by, last_measured, verdict):
        self.fingerprint = fingerprint
        self.country = country
        self.address = address
        self.port = port
        try:
            self.date = datetime.datetime.strptime(last_measured, "%Y-%m-%d")
        except (TypeError, ValueError):
            # TypeError covers a NULL (None) last_measured column; either
            # way, fall back to the epoch so the row sorts as oldest.
            logging.error("Could not convert SQL date string '%s' to "
                          "datetime object.", last_measured)
            self.date = datetime.datetime(1970, 1, 1, 0, 0)
        self.verdict = verdict

    def compact(self):
        """Return the (country, address, port) triple this measurement is about."""
        return (self.country, self.address, self.port)

    def __contains__(self, item):
        """``item in self``: True if *item* has the same country, address and
        port.

        Note that this does not affect ``self in some_list``, which uses
        ``==`` (object identity, since no ``__eq__`` is defined).
        """
        return (self.country == item.country and
                self.address == item.address and
                self.port == item.port)

    def newerThan(self, other):
        """Return True if this measurement's date is strictly later than *other*'s."""
        return self.date > other.date

    def conflicts(self, other):
        """Return True if *other* is about the same (country, address, port)
        but reached a different verdict.
        """
        return (self.verdict != other.verdict and
                self.country == other.country and
                self.address == other.address and
                self.port == other.port)
def getBlockedBridgesFromSql(sql_rows):
    """Return a dictionary that maps bridge fingerprints to a list of
    bridges that are known to be blocked somewhere.

    Rows are processed in the given order.  For each (country, address,
    port) of a bridge, a newer measurement supersedes an older one with the
    opposite verdict, and of several agreeing measurements only the newest
    is kept.  When conflicting measurements carry the same date, both are
    kept.

    :param list sql_rows: A list of tuples. Each tuple represents an SQL row
        of the BridgeMeasurements table, in column order.
    :rtype: dict
    :returns: A dictionary that maps bridge fingerprints (as strings) to a
        list of three-tuples that capture its blocking state: (country,
        address, port).
    """
    # Separately keep track of measurements that conclude that a bridge is
    # blocked or reachable.  Each maps fingerprint -> list of measurements,
    # with at most one entry per (country, address, port).
    blocked = {}
    reachable = {}

    def _shouldSkip(m1):
        """Return `True` if we can skip this measurement.

        As a side effect, drops an older measurement with the opposite
        verdict that ``m1`` supersedes.
        """
        # Use our 'reachable' dictionary if our original measurement says that
        # a bridge is blocked, and vice versa.  The purpose is to process
        # measurements that are possibly conflicting with the one at hand.
        d = reachable if m1.verdict == BRIDGE_BLOCKED else blocked
        maybe_conflicting = d.get(m1.fingerprint)
        if maybe_conflicting is None:
            # There is no potentially conflicting measurement.
            return False

        for m2 in maybe_conflicting:
            if m1.compact() != m2.compact():
                continue
            # Conflicting measurement.  If m2 is newer than m1, we believe m2.
            if m2.newerThan(m1):
                return True
            # Conflicting measurement.  If m1 is newer than m2, we believe m1
            # and remove m2.  We return immediately afterwards, so removing
            # from the list being iterated is safe.
            if m1.newerThan(m2):
                maybe_conflicting.remove(m2)
                # If we're left with an empty list, get rid of the dictionary
                # key altogether.
                if not maybe_conflicting:
                    del d[m1.fingerprint]
                return False
        return False

    for fields in sql_rows:
        m = BridgeMeasurement(*fields)
        if _shouldSkip(m):
            continue

        d = blocked if m.verdict == BRIDGE_BLOCKED else reachable
        other_measurements = d.get(m.fingerprint)
        if other_measurements is None:
            # We're dealing with the first "blocked" or "reachable" measurement
            # for the given bridge fingerprint.
            d[m.fingerprint] = [m]
        elif any(m.compact() == x.compact() for x in other_measurements):
            # An existing measurement agrees with this one: keep the newer of
            # the two.  (``m in other_measurements`` cannot be used here: list
            # membership compares with ==, which is identity for
            # BridgeMeasurement, so it never matched and duplicates piled up.
            # A later conflicting verdict then removed only one of them.)
            d[m.fingerprint] = [m if m.compact() == x.compact() and
                                m.newerThan(x) else x
                                for x in other_measurements]
        else:
            # We're dealing with a new measurement. Add it to the list.
            other_measurements.append(m)

    # Compact-ify: report only (country, address, port) for blocked bridges.
    return {fpr: [i.compact() for i in ms] for fpr, ms in blocked.items()}
def clearGlobalDB():
    """Start from scratch.

    Reset every module-level database global (filename, lock, lock count,
    shared Database and its reference count) to its initial value.  A
    shared Database that is still open is dropped without being closed.

    This is currently only used in unit tests.
    """
    global _DB_FNAME
    global _LOCK
    global _LOCKED
    global _OPENED_DB
    # _REFCOUNT must be declared global too; otherwise the assignment below
    # only binds a local name and the module's reference count stays stale.
    global _REFCOUNT

    _DB_FNAME = None
    _LOCK = None
    _LOCKED = 0
    _OPENED_DB = None
    _REFCOUNT = 0
def initializeDBLock():
    """Create the module-wide database lock unless it already exists.

    This must be called before the first getDB(), which asserts that the
    lock is set.  Calling it again keeps the existing lock.  The lock is a
    threading.RLock, so the thread holding it may acquire it again.
    """
    global _LOCK
    _LOCK = _LOCK or threading.RLock()
    assert _LOCK
def setDBFilename(sqlite_fname):
    """Remember ``sqlite_fname`` as the SQLite file that getDB() opens."""
    global _DB_FNAME
    _DB_FNAME = sqlite_fname
@contextmanager
def getDB(block=True):
    """Generator: Return a usable database handler

    Always return a :class:`bridgedb.Storage.Database` that is
    usable within the current thread. If a connection already exists
    and it was created by the current thread, then return the
    associated :class:`bridgedb.Storage.Database` instance. Otherwise,
    create a new instance, blocking until the existing connection
    is closed, if applicable.

    The shared Database is opened on first use and closed when the last
    nested getDB() context exits.  With ``block=False``, if the lock is held
    by another thread, ``False`` is yielded instead and no global state is
    touched.

    Note: This is a blocking call (by default), be careful about
    deadlocks!

    :param bool block: Whether to wait for the database lock.
    :rtype: :class:`bridgedb.Storage.Database` or False
    :returns: An instance of :class:`bridgedb.Storage.Database` used to
        query the database, or False if ``block`` is False and the lock
        could not be acquired.
    """
    global _LOCKED
    global _OPENED_DB
    global _REFCOUNT

    assert _LOCK
    if not _LOCK.acquire(block):
        # We don't own the lock, so there is nothing to count or release.
        yield False
        return

    _LOCKED += 1
    try:
        if not _OPENED_DB:
            assert _REFCOUNT == 0
            _OPENED_DB = Database(_DB_FNAME)
        # Only count ourselves once the Database actually exists, so a
        # failing open cannot drive the reference count negative.
        _REFCOUNT += 1
        try:
            yield _OPENED_DB
        finally:
            _REFCOUNT -= 1
            if _REFCOUNT == 0:
                _OPENED_DB.close()
                _OPENED_DB = None
    finally:
        _LOCKED -= 1
        _LOCK.release()
def dbIsLocked():
    """Return True if the getDB() lock counter is nonzero, i.e. some
    getDB() context currently holds the database lock.
    """
    return bool(_LOCKED)