Getting the file size of all dict files to be downloaded; the total comes to about 400 MB.
[worddb.git] / libs / dmigrations / mysql / generator.py
blobb9870e90f4c69254b441fe90925d4e165e4e2885
1 """
2 Functions that code-generate migrations, used by the ./manage.py dmigration
3 command.
4 """
5 from django.core.management.base import CommandError
6 from django.core.management.color import no_style
7 from django.db import connection
8 from django.db import models
9 from django.conf import settings
10 from dmigrations.generator_utils import save_migration
12 import re
def get_commands():
    """Return the dmigration sub-command table.

    Keys are the sub-command names accepted on the command line; values are
    the handler functions, each called as ``handler(args, output)``.
    """
    commands = dict(
        app=add_app,
        addindex=add_index,
        addcolumn=add_column,
        addtable=add_table,
        new=add_new,
        insert=add_insert,
    )
    return commands
def add_app(args, output):
    " <app>: Add tables for a new application"
    # NOTE(review): the docstring's leading space suggests the dmigration
    # command prints it as help text, so it is kept verbatim.
    #
    # Writes a migration whose up SQL creates every table of <app> (via
    # Django's sql_create) and whose down SQL drops them (via sql_delete in
    # this module).
    if len(args) != 1:
        raise CommandError('./manage.py migration app <name-of-app>')
    (app_label,) = args
    app = models.get_app(app_label)
    from django.core.management.sql import sql_create

    plain_style = no_style()
    create_statements = sql_create(app, plain_style)
    drop_statements = sql_delete(app, plain_style)

    # Migration files are named after the app module, dots -> underscores.
    file_label = app.__name__.replace('.', '_')
    body = app_mtemplate % (
        clean_up_create_sql(create_statements),
        clean_up_create_sql(drop_statements),
    )
    save_migration(output, migration_code(body), file_label)
def add_table(args, output):
    " <app> <model>: Add tables for a new model"
    # NOTE(review): the docstring's leading space suggests it is shown as
    # command help, so it is kept verbatim.
    #
    # Writes a migration that creates the table(s) for a single model of
    # <app> and drops them again on the way down.
    #
    # Raises CommandError on a wrong argument count, and Exception when the
    # model cannot be found.
    if len(args) != 2:
        # The usage text previously described the `app` command.
        raise CommandError('./manage.py migration addtable <app> <model>')
    app_label, model = args
    app = models.get_app(app_label)
    app_name = app.__name__.replace('.', '_')
    model_to_add = models.get_model(app_label, model)

    if not model_to_add:
        raise Exception("Model %s in app %s not found" % (model, app_label))

    up_sql = _create_table_sql(app, model_to_add)
    if not up_sql:
        raise Exception("Model %s in app %s not found" % (model, app_label))

    down_sql = _drop_table_sql(up_sql)

    migration_output = app_mtemplate % (
        clean_up_create_sql(up_sql), clean_up_create_sql(down_sql)
    )
    migration_output = migration_code(migration_output)

    save_migration(output, migration_output, app_name)


def _create_table_sql(app, model_to_add):
    """Return the SQL that creates ``model_to_add``'s table.

    The result also includes the FK constraints that can be emitted
    immediately and the many-to-many join tables.

    Adapted from django.core.management.sql.sql_create, restricted to one
    model. Ideally Django would expose this logic for reuse.
    """
    style = no_style()
    app_models = models.get_models(app)
    tables = connection.introspection.table_names()
    # Models already installed (per introspection of existing tables) that
    # belong to other apps. A distinct loop name is used on purpose: under
    # Python 2 a list-comprehension variable leaks into the enclosing scope.
    # Reusing a caller-visible name here previously clobbered it.
    known_models = set([
        installed for installed in
        connection.introspection.installed_models(tables)
        if installed not in app_models
    ])
    pending_references = {}
    up_sql = []

    sql_output, references = connection.creation.sql_create_model(
        model_to_add, style, known_models
    )
    up_sql.extend(sql_output)
    for refto, refs in references.items():
        pending_references.setdefault(refto, []).extend(refs)
        if refto in known_models:
            up_sql.extend(
                connection.creation.sql_for_pending_references(
                    refto, style, pending_references
                )
            )
    # Emit pending constraints that point at the new model itself, e.g. a
    # self-referencing ForeignKey. This must be model_to_add, matching the
    # per-model step in Django's sql_create.
    up_sql.extend(
        connection.creation.sql_for_pending_references(
            model_to_add, style, pending_references
        )
    )

    # Create the many-to-many join tables.
    up_sql.extend(
        connection.creation.sql_for_many_to_many(model_to_add, style)
    )
    return up_sql


def _drop_table_sql(up_sql):
    """Return ``DROP TABLE`` statements for each ``CREATE TABLE`` in up_sql.

    Statements are returned in reverse creation order, so join tables are
    dropped before the tables they were created after.
    """
    return [
        'DROP TABLE %s;' % sql.split()[2]
        for sql in reversed(up_sql)
        if sql.startswith('CREATE TABLE')
    ]
def add_index(args, output):
    " <app> <model> <column>: Add an index"
    # NOTE(review): docstring kept verbatim (leading space suggests it is
    # used as command help text).
    #
    # Writes a migration containing a single m.AddIndex for the given
    # app/model/column; the file is named add_index_<app>_<model>_<column>.
    if len(args) == 3:
        app_label, model, column = args
    else:
        raise CommandError(
            './manage.py migration addindex <app> <model> <column>'
        )

    names = (app_label, model, column)
    code = migration_code(add_index_mtemplate % names)
    save_migration(output, code, 'add_index_%s_%s_%s' % names)
def add_column(args, output):
    " <app> <model> <column> [<column2> ...]: Add one or more columns"
    # NOTE(review): docstring kept verbatim (leading space suggests it is
    # used as command help text).
    #
    # Writes a migration with one m.AddColumn per named field of <model>.
    # Each column definition is copied out of the CREATE TABLE statement
    # Django would generate for the model. Foreign-key columns get the
    # referenced table name as an extra argument and come after the plain
    # columns.
    #
    # Raises CommandError on bad arguments or an unknown model.
    import sys  # needed for the FK warning; never imported at module level

    if len(args) < 3:
        raise CommandError(
            './manage.py migration addcolumn <app> <model> <column> '
            '[<column2> ...]'
        )
    app_label, model, columns = args[0], args[1], args[2:]
    actual_model = models.get_model(app_label, model)
    if not actual_model:
        # Previously this failed later with AttributeError on None.
        raise CommandError(
            "Model %s in app %s not found" % (model, app_label)
        )

    style = no_style()
    sql, _references = connection.creation.sql_create_model(
        actual_model, style, set()
    )
    create_sql = sql[0]  # the model's CREATE TABLE statement

    plain_defs = []
    fk_defs = []
    fk_columns = []
    for column in columns:
        # Look the field up once and reuse it for the FK target table.
        field = actual_model._meta.get_field_by_name(column)[0]
        is_foreign_key = isinstance(field, models.ForeignKey)
        col_spec = extract_column_spec(create_sql, column, is_foreign_key)
        if is_foreign_key:
            fk_columns.append(column)
            fk_defs.append(add_column_foreignkey_mtemplate % (
                app_label, model, column, col_spec,
                field.rel.to._meta.db_table
            ))
        else:
            plain_defs.append(add_column_mtemplate % (
                app_label, model, column, col_spec
            ))

    if fk_defs:
        # Written via sys.stderr.write (works on Python 2 and 3); the output
        # is the same as the former `print >>sys.stderr`, including the
        # trailing newline.
        sys.stderr.write("""Warning!
You have added columns that are foreign keys (%s).
These will be added as nullable. If you need them to be NOT NULL, then you
have to write another migration to do that, after you've populated them
with data.""" % ','.join(fk_columns) + '\n')

    migration_output = migration_code(*(plain_defs + fk_defs))

    if len(columns) == 1:
        migration_name = 'add_column_%s_to_%s_%s' % (
            columns[0], app_label, model
        )
    else:
        migration_name = 'add_columns_%s_to_%s_%s' % (
            "_and_".join(columns), app_label, model
        )

    save_migration(output, migration_output, migration_name)
def add_new(args, output):
    " <description>: Create empty migration (uses description in filename)"
    # NOTE(review): docstring kept verbatim (leading space suggests it is
    # used as command help text).
    #
    # Writes skeleton_template, filled in with the configured backend name
    # (DMIGRATIONS_DATABASE_BACKEND, default 'mysql'). The filename part is
    # the description words joined by '_' and lower-cased.
    if not args:
        raise CommandError('./manage.py migration new <description>')

    backend = getattr(settings, 'DMIGRATIONS_DATABASE_BACKEND', 'mysql')
    skeleton = skeleton_template % backend
    save_migration(output, skeleton, '_'.join(args).lower())
def add_insert(args, output):
    " <app> <model>: Create insert migration for data in table"
    # NOTE(review): docstring kept verbatim (leading space suggests it is
    # used as command help text).
    #
    # Dumps every row of table <app>_<model> into an m.InsertRows migration.
    # The migration inserts the rows going up and deletes them by id going
    # down.
    #
    # Raises CommandError on bad arguments, or if the table's first column is
    # not an `id` primary key.
    import pprint  # was used below but never imported, causing a NameError

    if len(args) != 2:
        raise CommandError('./manage.py migration insert <app> <model>')

    app_label, model = args
    # Table name is derived from the arguments, not from the model's Meta.
    table_name = '%s_%s' % (app_label, model)

    def fetch_all(sql):
        "Runs sql on a new cursor and returns all rows; always closes it."
        cursor = connection.cursor()
        try:
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            cursor.close()

    def get_columns(table_name):
        "Returns columns for table"
        rows = fetch_all('describe %s' % table_name)

        # Sanity check that the first column is called 'id' and is the
        # primary key: delete_ids below takes each row's first value. Explicit
        # raises (unlike `assert`) are not stripped under `python -O`.
        first = rows[0]
        if first[0] != u'id':
            raise CommandError('First column must be id')
        if first[3] != u'PRI':
            raise CommandError('First column must be primary key')

        return [r[0] for r in rows]

    def get_dump(table_name):
        "Returns {'table_name':..., 'columns':..., 'rows':...}"
        columns = get_columns(table_name)
        # Escape column names with `backticks` - so columns with names that
        # match MySQL reserved words (e.g. "order") don't break things
        escaped_columns = ['`%s`' % column for column in columns]
        rows = fetch_all(
            'SELECT %s FROM %s' % (', '.join(escaped_columns), table_name)
        )
        return {
            'table_name': table_name,
            'columns': columns,
            'rows': rows,
        }

    dump = get_dump(table_name)

    migration_output = insert_mtemplate % {
        'table_name': dump['table_name'],
        'columns': repr(dump['columns']),
        'insert_rows': pprint.pformat(dump['rows']),
        'delete_ids': ', '.join([str(r[0]) for r in dump['rows']]),
    }
    migration_output = migration_code(migration_output)

    save_migration(output, migration_output, 'insert_into_%s_%s' % (
        app_label, model
    ))
def sql_delete(app, style):
    """Return the DROP TABLE SQL statements for every model of ``app``.

    This is a modified version of django.core.management.sql.sql_delete.
    The original only emits DROP TABLE statements for tables that currently
    exist in the database; a migration needs them all regardless.

    The statement list is reversed before being returned to deal with
    table dependencies.
    """
    # Uses the module-level `connection` and `models` imports. The unused
    # `table_names` local and the unused truncate_name / contenttypes
    # `generic` imports from the Django original are dropped.
    output = []

    # Output DROP TABLE statements for standard application tables.
    to_delete = set()
    references_to_delete = {}
    app_models = models.get_models(app)
    for model in app_models:
        for f in model._meta.local_fields:
            # Record FKs whose target model has not been seen yet, so
            # sql_destroy_model can drop those constraints first.
            if f.rel and f.rel.to not in to_delete:
                references_to_delete.setdefault(f.rel.to, []).append(
                    (model, f)
                )
        to_delete.add(model)

    for model in app_models:
        output.extend(
            connection.creation.sql_destroy_model(
                model, references_to_delete, style
            )
        )

    # Output DROP TABLE statements for many-to-many tables.
    for model in app_models:
        for f in model._meta.local_many_to_many:
            output.extend(
                connection.creation.sql_destroy_many_to_many(model, f, style)
            )

    return output[::-1]  # Reverse it, to deal with table dependencies.
299 def clean_up_create_sql(sqls):
300 "Ensures create table uses correct engine, cleans up whitespace"
302 engine = getattr(settings, 'DMIGRATIONS_MYSQL_ENGINE', 'InnoDB')
304 def neat_format(sql):
305 def indent4(s):
306 lines = s.split('\n')
307 return '\n'.join([' %s' % line for line in lines])
309 bits = ['"""\n%s\n"""' % indent4(bit) for bit in sql]
310 return '[%s]' % ', '.join(bits)
312 def fix_create_table(sql):
313 if sql.strip().startswith("CREATE TABLE"):
314 # Find the last ')'
315 last_index = sql.rindex(')')
316 tail = sql[last_index:]
317 if 'InnoDB' not in tail:
318 tail = tail.replace(
319 ')', ') ENGINE=%s DEFAULT CHARSET=utf8' % engine
321 sql = sql[:last_index] + tail
322 return sql
324 return neat_format(map(fix_create_table, sqls))
def extract_column_spec(sql, column, is_foreign_key=False):
    """Extract a column's definition from a CREATE TABLE statement.

    Expects each column to be defined on its own line, starting with the
    backtick-quoted column name. For a foreign key, the ``<column>_id``
    column is looked up instead.

    Returns the definition with the quoted name and any trailing comma
    removed, e.g. ``varchar(10) NOT NULL``.

    Raises ValueError when no line defines the column. This replaces an
    ``assert`` that ``python -O`` stripped, which silently returned None.
    """
    if is_foreign_key:
        escaped_column = '`%s_id`' % column
    else:
        escaped_column = '`%s`' % column
    for line in sql.split('\n'):
        line = line.strip()
        if line.startswith(escaped_column):
            # Remove only the leading name. str.replace would also strip any
            # later occurrence inside the definition (e.g. in a COMMENT).
            spec = line[len(escaped_column):]
            return spec.rstrip(',').strip()  # drop the trailing comma
    raise ValueError('Could not find column spec for column %s' % column)
# Source of a generated migration module. %(db_engine)s picks the
# dmigrations backend package; %(migration_body)s is the migration
# expression assigned to the module-level name `migration`.
migration_template = (
    "from dmigrations.%(db_engine)s import migrations as m\n"
    "import datetime\n"
    "migration = %(migration_body)s\n"
)
def migration_code(*migration_defs):
    """Wrap one or more migration expressions in a full migration module.

    A single definition is used as the migration directly. Any other count
    (including zero) is combined into ``m.Compound([...])``, one definition
    per line. The backend comes from DMIGRATIONS_DATABASE_BACKEND (default
    'mysql'). Returns the rendered migration_template text.
    """
    if len(migration_defs) != 1:
        items = ["    %s,\n" % item for item in migration_defs]
        migration_body = "m.Compound([\n" + "".join(items) + "])\n"
    else:
        (migration_body,) = migration_defs

    backend = getattr(settings, 'DMIGRATIONS_DATABASE_BACKEND', 'mysql')
    return migration_template % {
        'db_engine': backend,
        'migration_body': migration_body,
    }
# Templates for code generation.

# m.AddColumn(app_label, model, column, column_spec)
add_column_mtemplate = "m.AddColumn('%s', '%s', '%s', '%s')"
# Same as above plus the referenced table's db_table (foreign keys).
add_column_foreignkey_mtemplate = "m.AddColumn('%s', '%s', '%s', '%s', '%s')"

# m.AddIndex(app_label, model, column)
add_index_mtemplate = "m.AddIndex('%s', '%s', '%s')"

# Filled with the Python list literals built by clean_up_create_sql().
app_mtemplate = "m.Migration(sql_up=%s, sql_down=%s)"

# Keyed placeholders: table_name, columns, insert_rows, delete_ids.
insert_mtemplate = (
    "m.InsertRows(\n"
    "    table_name = '%(table_name)s',\n"
    "    columns = %(columns)s,\n"
    "    insert_rows = %(insert_rows)s,\n"
    "    delete_ids = [%(delete_ids)s]\n"
    ")"
)

# Empty hand-editable migration; %s is the dmigrations backend name.
skeleton_template = """from dmigrations.%s import migrations as m

class CustomMigration(m.Migration):
    def __init__(self):
        sql_up = []
        sql_down = []
        super(CustomMigration, self).__init__(
            sql_up=sql_up, sql_down=sql_down
        )
    # Or override the up() and down() methods

migration = CustomMigration()
"""