trac#687: Add unit tests for `redirect` and `redirect_obj`.
[larjonas-mediagoblin.git] / mediagoblin / db / migration_tools.py
blobfae9864347d0a245a1f17b8d1c0813a0ca3f858c
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
17 from __future__ import unicode_literals
19 import logging
20 import os
22 from alembic import command
23 from alembic.config import Config
24 from alembic.migration import MigrationContext
26 from mediagoblin.db.base import Base
27 from mediagoblin.tools.common import simple_printer
28 from sqlalchemy import Table
29 from sqlalchemy.sql import select
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class TableAlreadyExists(Exception):
    """
    Raised when a table that is about to be created already exists.

    MigrationManager.init_tables raises it as a sanity check before
    creating a section's tables.
    """
class AlembicMigrationManager(object):
    """
    Thin wrapper around Alembic's command API.

    The Alembic configuration is read from ``alembic.ini`` in the
    directory two levels above the one containing this module.
    """

    def __init__(self, session):
        module_dir = os.path.dirname(__file__)
        project_root = os.path.abspath(
            os.path.dirname(os.path.dirname(module_dir)))
        config_file = os.path.join(project_root, 'alembic.ini')
        self.alembic_cfg = Config(config_file)
        self.session = session

    def get_current_revision(self):
        """Return the revision reported by Alembic's MigrationContext."""
        migration_context = MigrationContext.configure(self.session.bind)
        return migration_context.get_current_revision()

    def upgrade(self, version):
        """Upgrade to ``version``; any falsy value means ``'head'``."""
        target = version if version else 'head'
        return command.upgrade(self.alembic_cfg, target)

    def downgrade(self, version):
        """
        Downgrade to ``version``.

        None, an int, or an all-digit string (presumably a legacy
        sqlalchemy-migrate version number) is mapped to ``'base'``.
        """
        legacy_or_missing = (version is None
                             or isinstance(version, int)
                             or version.isdigit())
        target = 'base' if legacy_or_missing else version
        return command.downgrade(self.alembic_cfg, target)

    def stamp(self, revision):
        """Stamp the database with ``revision`` without running migrations."""
        return command.stamp(self.alembic_cfg, revision=revision)

    def init_tables(self):
        """
        Create every table known to ``Base.metadata`` and stamp the
        database with the most recent Alembic revision.
        """
        Base.metadata.create_all(self.session.bind)
        # XXX: a better way is needed to detect existing installations
        # that use sqlalchemy-migrate, since those should not have all
        # tables created for them.
        command.stamp(self.alembic_cfg, 'head')

    def init_or_migrate(self, version=None):
        """Upgrade the database to ``version`` (default: ``'head'``)."""
        # XXX: once sqlalchemy-migrate is ditched entirely, this should
        # first call self.init_tables() when
        # self.get_current_revision() is None.
        self.upgrade(version)
class MigrationManager(object):
    """
    Migration handling tool.

    Takes information about a database, lets you update the database
    to the latest migrations, etc.
    """

    def __init__(self, name, models, foundations, migration_registry, session,
                 printer=simple_printer):
        """
        Args:
         - name: identifier of this section of the database
         - models: model classes whose tables belong to this section
         - foundations: mapping of model class -> iterable of keyword
           dicts, each used to construct one default row
         - migration_registry: where we should find all migrations to
           run (mapping of migration number -> callable taking the session)
         - session: session we're going to migrate
         - printer: callable that receives progress messages
        """
        self.name = name
        self.models = models
        self.foundations = foundations
        self.session = session
        self.migration_registry = migration_registry
        self._sorted_migrations = None
        self.printer = printer

        # For convenience
        from mediagoblin.db.models import MigrationData

        self.migration_model = MigrationData
        self.migration_table = MigrationData.__table__

    @property
    def sorted_migrations(self):
        """
        Sort migrations if necessary and store in self._sorted_migrations

        Returns a list of (migration_number, migration_func) tuples in
        ascending migration-number order.
        """
        # Note: an empty list is falsy, so an empty registry is
        # re-checked on every access.
        if not self._sorted_migrations:
            self._sorted_migrations = sorted(
                self.migration_registry.items(),
                # sort on the key... the migration number
                key=lambda migration_tuple: migration_tuple[0])

        return self._sorted_migrations

    @property
    def migration_data(self):
        """
        Get the migration row associated with this object, if any.

        Every access issues a fresh database query.
        """
        return self.session.query(
            self.migration_model).filter_by(name=self.name).first()

    @property
    def latest_migration(self):
        """
        Return a migration number for the latest migration, or 0 if
        there are no migrations.
        """
        if self.sorted_migrations:
            return self.sorted_migrations[-1][0]
        else:
            # If no migrations have been set, we start at 0.
            return 0

    @property
    def database_current_migration(self):
        """
        Return the current migration in the database.

        Returns None if the migrations table does not exist yet, or if it
        has no row for this section.
        """
        # If the table doesn't even exist, return None.
        if not self.migration_table.exists(self.session.bind):
            return None

        # Fetch the row once: each access of self.migration_data queries
        # the database again.
        migration_data = self.migration_data
        if migration_data is None:
            return None

        return migration_data.version

    def set_current_migration(self, migration_number=None):
        """
        Set the migration in the database to migration_number
        (or, the latest available when migration_number is None)

        Commits the session. Requires an existing migration record for
        this section.
        """
        # Compare with None explicitly so an explicit 0 (the "no
        # migrations" state) is stored as given, not replaced by latest.
        if migration_number is None:
            migration_number = self.latest_migration
        self.migration_data.version = migration_number
        self.session.commit()

    def migrations_to_run(self):
        """
        Get a list of migrations to run still, if any.

        Returns (migration_number, migration_func) tuples whose number is
        greater than the one recorded in the database.

        Note that this will fail if there's no migration record for
        this class!
        """
        # Evaluate the property once; it performs database queries.
        db_current_migration = self.database_current_migration
        assert db_current_migration is not None

        return [
            (migration_number, migration_func)
            for migration_number, migration_func in self.sorted_migrations
            if migration_number > db_current_migration]

    def init_tables(self):
        """
        Create all tables relative to this package

        Raises TableAlreadyExists if any of this section's tables is
        already present.
        """
        # sanity check before we proceed, none of these should be created
        for model in self.models:
            # Maybe in the future just print out a "Yikes!" or something?
            if model.__table__.exists(self.session.bind):
                raise TableAlreadyExists(
                    u"Intended to create table '%s' but it already exists" %
                    model.__table__.name)

        self.migration_model.metadata.create_all(
            self.session.bind,
            tables=[model.__table__ for model in self.models])

    def populate_table_foundations(self):
        """
        Create the table foundations (default rows) as laid out in FOUNDATIONS
        in mediagoblin.db.models

        Rows are added to the session but not committed here.
        """
        for Model, rows in self.foundations.items():
            self.printer(u' + Laying foundations for %s table\n' %
                         (Model.__name__))
            for parameters in rows:
                new_row = Model(**parameters)
                self.session.add(new_row)

    def create_new_migration_record(self):
        """
        Create a new migration record for this migration set

        The record is stamped with the latest known migration number and
        committed.
        """
        migration_record = self.migration_model(
            name=self.name,
            version=self.latest_migration)
        self.session.add(migration_record)
        self.session.commit()

    def dry_run(self):
        """
        Print out a dry run of what we would have upgraded.

        Returns u'inited' if this section would be initialized,
        u'migrated' if there are pending migrations, otherwise None.
        """
        if self.database_current_migration is None:
            self.printer(
                u'~> Woulda initialized: %s\n' % self.name_for_printing())
            return u'inited'

        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'~> Woulda updated %s:\n' % self.name_for_printing())

            # migrations_to_run is already a list; iterate it directly.
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Would update %s, "%s"\n' % (
                        migration_number, migration_func.__name__))

            return u'migrated'

        return None

    def name_for_printing(self):
        """Return a human-readable label for this section."""
        if self.name == u'__main__':
            return u"main mediagoblin tables"
        else:
            return u'plugin "%s"' % self.name

    def init_or_migrate(self):
        """
        Initialize the database or migrate if appropriate.

        Returns information about whether or not we initialized
        ('inited'), migrated ('migrated'), or did nothing (None)
        """
        assure_migrations_table_setup(self.session)

        # Find out what migration number, if any, this database data is at,
        # and what the latest is.
        migration_number = self.database_current_migration

        # Is this our first time? Is there even a table entry for
        # this identifier?
        # If so:
        #  - create all tables
        #  - create record in migrations registry
        #  - print / inform the user
        #  - return 'inited'
        if migration_number is None:
            self.printer(u"-> Initializing %s... " % self.name_for_printing())

            self.init_tables()
            # auto-set at latest migration number
            self.create_new_migration_record()
            self.printer(u"done.\n")
            self.populate_table_foundations()
            # Commits the foundation rows added above.
            self.set_current_migration()
            return u'inited'

        # Run migrations, if appropriate.
        migrations_to_run = self.migrations_to_run()
        if migrations_to_run:
            self.printer(
                u'-> Updating %s:\n' % self.name_for_printing())
            for migration_number, migration_func in migrations_to_run:
                self.printer(
                    u' + Running migration %s, "%s"... ' % (
                        migration_number, migration_func.__name__))
                migration_func(self.session)
                # Record progress after each step so a later failure
                # does not re-run already-applied migrations.
                self.set_current_migration(migration_number)
                self.printer('done.\n')

            return u'migrated'

        # Otherwise return None. Well it would do this anyway, but
        # for clarity... ;)
        return None
class RegisterMigration(object):
    """
    Tool for registering migrations

    Call like:

        @RegisterMigration(33, MIGRATIONS)
        def update_dwarves(database):
            [...]

    where MIGRATIONS is the migration registry (a dict mapping migration
    numbers to migration callables) that the migration is added to. The
    registry argument is required. Decorating returns the migration
    function unchanged.

    Note, the number of your migration should NEVER be 0 or less than
    0. 0 is the default "no migrations" state!

    Raises ValueError if the migration number is not positive or is
    already present in the registry.
    """
    def __init__(self, migration_number, migration_registry):
        # Explicit checks rather than ``assert`` so the validation still
        # runs under ``python -O``.
        if not (migration_number > 0):
            raise ValueError("Migration number must be > 0!")
        if migration_number in migration_registry:
            raise ValueError(
                "Duplicate migration numbers detected! That's not allowed!")

        self.migration_number = migration_number
        self.migration_registry = migration_registry

    def __call__(self, migration):
        self.migration_registry[self.migration_number] = migration
        return migration
def assure_migrations_table_setup(db):
    """
    Create the MigrationData table if it does not exist yet.

    :param db: session whose ``bind`` is used both to check for the
        table and to create it.
    """
    from mediagoblin.db.models import MigrationData

    migrations_table = MigrationData.__table__
    if migrations_table.exists(db.bind):
        return
    MigrationData.metadata.create_all(db.bind, tables=[migrations_table])
def inspect_table(metadata, table_name):
    """
    Return a reference to an already existing table.

    The table is reflected (autoloaded) into ``metadata`` using the
    connectable found at ``metadata.bind``.
    """
    engine = metadata.bind
    return Table(table_name, metadata,
                 autoload_with=engine, autoload=True)
def replace_table_hack(db, old_table, replacement_table):
    """
    Fully replace a current table with a new one during a migration.

    This is needed because some changes are tricky in some situations.
    For example, dropping a boolean column in sqlite is impossible without
    this method.

    Every row of ``old_table`` is copied into ``replacement_table``,
    restricted to the columns the replacement still has. ``old_table`` is
    then dropped, and ``replacement_table`` is renamed to the old name.
    The session is committed after each of these three steps.

    :param old_table: A ref to the old table, gotten through
        inspect_table
    :param replacement_table: A ref to the new table, gotten through
        inspect_table

    Users are encouraged to use sqlalchemy-migrate's replace-table
    solutions unless that is not possible... in which case, this solution
    works, at least for sqlite.
    """
    kept_names = replacement_table.columns.keys()
    original_name = old_table.name
    carried_columns = [col for col in old_table.columns
                       if col.name in kept_names]

    for old_row in db.execute(select(carried_columns)):
        db.execute(replacement_table.insert().values(**old_row))
    db.commit()

    old_table.drop()
    db.commit()

    # ``rename`` is not core SQLAlchemy Table API; presumably provided by
    # sqlalchemy-migrate's changeset extension -- confirm.
    replacement_table.rename(original_name)
    db.commit()