Fix #5398 and #5395 - Fix tests failing due to problem creating connection for alembic
[larjonas-mediagoblin.git] / mediagoblin / db / base.py
blob0f17a3a8d36fbe96f25be1240a93de249bce58d1
# GNU MediaGoblin -- federated, autonomous media hosting
# Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
16 import six
17 import copy
19 from sqlalchemy.ext.declarative import declarative_base
20 from sqlalchemy import inspect
22 from mediagoblin.tools.transition import DISABLE_GLOBALS
# The module-level ``Session`` registry only exists when globals are enabled;
# with DISABLE_GLOBALS set, neither the import nor the name is created.
if not DISABLE_GLOBALS:
    from sqlalchemy.orm import (
        scoped_session,
        sessionmaker,
    )

    # Scoped session built on an unconfigured sessionmaker.
    Session = scoped_session(sessionmaker())
class FakeCursor(object):
    """Wrap a cursor-like object and pass every row through ``mapper``.

    ``cursor`` must provide whatever the called method delegates to:
    ``count()``, iteration, indexing and ``slice()``.  ``filter`` is handed
    to ``six.moves.filter`` as the predicate for iteration and ``slice``;
    when it is ``None`` (the default) mapped values that are falsy are
    dropped, following the usual ``filter(None, ...)`` semantics.
    """

    def __init__(self, cursor, mapper, filter=None):
        self.cursor = cursor
        self.mapper = mapper
        self.filter = filter

    def _mapped_and_filtered(self, rows):
        """Return a lazy iterator over ``rows``: mapped first, then filtered."""
        mapped_rows = six.moves.map(self.mapper, rows)
        return six.moves.filter(self.filter, mapped_rows)

    def count(self):
        """Delegate to the wrapped cursor's ``count()`` (no filtering applied)."""
        return self.cursor.count()

    def __copy__(self):
        # Hook used by copy.copy(): the wrapped cursor is shallow-copied,
        # while the mapper and the filter are shared with the original.
        duplicated_cursor = copy.copy(self.cursor)
        return FakeCursor(duplicated_cursor, self.mapper, self.filter)

    def __iter__(self):
        return self._mapped_and_filtered(self.cursor)

    def __getitem__(self, key):
        # Only the mapper is applied here; the filter is not consulted.
        raw_row = self.cursor[key]
        return self.mapper(raw_row)

    def slice(self, *args, **kwargs):
        """Slice the wrapped cursor and return the mapped, filtered rows as a list."""
        window = self.cursor.slice(*args, **kwargs)
        return list(self._mapped_and_filtered(window))
class GMGTableBase(object):
    """Shared behaviour for the declarative models built on this class.

    Provides dict-flavoured attribute access (``get``/``setdefault``),
    ``save`` and the hard/soft deletion machinery.
    """

    # Deletion types
    HARD_DELETE = "hard-deletion"
    SOFT_DELETE = "soft-deletion"

    # Mode used by delete() when the caller does not pass ``deletion``.
    deletion_mode = HARD_DELETE

    @property
    def _session(self):
        """Session reported by ``inspect(self)``; ``save`` and ``hard_delete``
        handle the case where this is ``None``."""
        return inspect(self).session

    @property
    def _app(self):
        """The ``app`` attribute of the bind of this object's session."""
        return self._session.bind.app

    # The class-level ``query`` property is only defined when globals are
    # enabled, because it is derived from the module-level ``Session``.
    if not DISABLE_GLOBALS:
        query = Session.query_property()

    def get(self, key):
        """Return the attribute named ``key`` (raises AttributeError if absent)."""
        return getattr(self, key)

    def setdefault(self, key, defaultvalue):
        """Return the attribute named ``key``; ``defaultvalue`` is never used."""
        # The key *has* to exist on sql.
        return getattr(self, key)

    def save(self, commit=True):
        """Add this object to its session and commit or flush.

        :param commit: commit the session when true, otherwise only flush.
        :raises AssertionError: when no session can be found for the object.
        """
        sess = self._session
        if sess is None and not DISABLE_GLOBALS:
            # Not attached yet: fall back to the global scoped session.
            sess = Session()
        # Raised explicitly (not via ``assert``) so the check also runs under
        # ``python -O``; the exception type and message are unchanged.
        if sess is None:
            raise AssertionError(
                "Can't save, %r has a detached session" % self)
        sess.add(self)
        if commit:
            sess.commit()
        else:
            sess.flush()

    def delete(self, commit=True, deletion=None):
        """Delete the object either using soft or hard deletion.

        Before deleting, rows that reference the object through its
        GenericModelReference are cleaned up: matching CollectionItem and
        Notification rows are deleted and matching Report rows get their
        ``object_id`` set to ``None``.

        :param commit: forwarded to the report saves and to the final
            ``hard_delete``/``soft_delete`` call.
        :param deletion: ``HARD_DELETE`` or ``SOFT_DELETE``; ``None`` means
            use ``self.deletion_mode``.
        :returns: whatever the selected deletion method returns.
        :raises ValueError: when ``deletion`` is not a known mode.
        """
        # Get the setting in the model args if none has been specified.
        if deletion is None:
            deletion = self.deletion_mode

        # If the item is in any collection then it should be removed, this will
        # cause issues if it isn't. See #5382.
        # Import here to prevent cyclic imports.
        # NOTE(review): the ``.query`` lookups below rely on the ``query``
        # property, which this class only defines when DISABLE_GLOBALS is false.
        from mediagoblin.db.models import CollectionItem, GenericModelReference, \
            Report, Notification

        # Some of the models don't have an "id" field which means they can't be
        # used with GMR, these models won't be in collections because they
        # can't be. We can skip all of this.
        if hasattr(self, "id"):
            # First find the GenericModelReference for this object
            gmr = GenericModelReference.query.filter_by(
                obj_pk=self.id,
                model_type=self.__tablename__
            ).first()

            # If there is no gmr then we've got lucky, a GMR is a requirement of
            # being in a collection.
            if gmr is not None:
                # Delete collections found
                items = CollectionItem.query.filter_by(
                    object_id=gmr.id
                )
                items.delete()

                # Delete notifications found
                notifications = Notification.query.filter_by(
                    object_id=gmr.id
                )
                notifications.delete()

                # Set None on reports found
                reports = Report.query.filter_by(
                    object_id=gmr.id
                )
                for report in reports:
                    report.object_id = None
                    report.save(commit=commit)

        # Hand off to the correct deletion function.
        if deletion == self.HARD_DELETE:
            return self.hard_delete(commit=commit)
        elif deletion == self.SOFT_DELETE:
            return self.soft_delete(commit=commit)
        else:
            raise ValueError(
                "Invalid deletion mode {mode!r}".format(
                    mode=deletion
                )
            )

    def soft_delete(self, commit):
        """Replace the object with a Graveyard tombstone, then hard-delete it.

        GenericModelReference rows pointing at this object are re-pointed at
        the tombstone before the object itself is removed.

        :param commit: forwarded to ``hard_delete``.
        """
        # Create the graveyard version of this model
        # Importing this here due to cyclic imports
        from mediagoblin.db.models import User, Graveyard, GenericModelReference

        tombstone = Graveyard()
        if getattr(self, "public_id", None) is not None:
            tombstone.public_id = self.public_id

        # This is a special case, we don't want to save any actor if the thing
        # being soft deleted is a User model as this would create circular
        # ForeignKeys
        if not isinstance(self, User):
            tombstone.actor = User.query.filter_by(
                id=self.actor
            ).first()
        tombstone.object_type = self.object_type
        # save(commit=False) flushes, which happens before ``tombstone.id``
        # is read below.
        tombstone.save(commit=False)

        # There will be a lot of places where the GenericForeignKey will point
        # to the model, we want to remap those to our tombstone.
        # (The value returned by update() is not needed, so it is not bound.)
        GenericModelReference.query.filter_by(
            obj_pk=self.id,
            model_type=self.__tablename__
        ).update({
            "obj_pk": tombstone.id,
            "model_type": tombstone.__tablename__,
        })

        # Now we can go ahead and actually delete the model.
        return self.hard_delete(commit=commit)

    def hard_delete(self, commit):
        """Mark the object for deletion in its session.

        :param commit: commit the session when true; otherwise the deletion
            is only registered with the session (no commit or flush is issued
            here).
        :raises AssertionError: when the object has no session.
        """
        sess = self._session
        # Explicit raise instead of ``assert`` so the guard survives
        # ``python -O``; same exception type and message as before.
        if sess is None:
            raise AssertionError("Not going to delete detached %r" % self)
        sess.delete(self)
        if commit:
            sess.commit()
# Declarative base class built on GMGTableBase, so classes derived from
# ``Base`` inherit its helpers.
Base = declarative_base(
    cls=GMGTableBase,
)
class DictReadAttrProxy(object):
    """
    Read-only, dict-style view of an object's attributes.

    ``proxy['key']`` returns ``obj.key``; nothing else of the wrapped
    object is exposed through the proxy's own interface.
    """

    def __init__(self, proxied_obj):
        self.proxied_obj = proxied_obj

    def __getitem__(self, key):
        """Return attribute ``key`` of the proxied object.

        Raises KeyError (instead of AttributeError) when it is missing, so
        the proxy behaves like a mapping.
        """
        target = self.proxied_obj
        try:
            value = getattr(target, key)
        except AttributeError:
            message = "%r is not an attribute on %r" % (key, target)
            raise KeyError(message)
        return value