trac#687: Add unit tests for `redirect` and `redirect_obj`.
[larjonas-mediagoblin.git] / mediagoblin / db / base.py
blob11afbcec5726210c5289e9ef59c41471d71ef819
1 # GNU MediaGoblin -- federated, autonomous media hosting
2 # Copyright (C) 2011, 2012 MediaGoblin contributors. See AUTHORS.
4 # This program is free software: you can redistribute it and/or modify
5 # it under the terms of the GNU Affero General Public License as published by
6 # the Free Software Foundation, either version 3 of the License, or
7 # (at your option) any later version.
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU Affero General Public License for more details.
14 # You should have received a copy of the GNU Affero General Public License
15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 import six
17 import copy
19 from sqlalchemy.ext.declarative import declarative_base
20 from sqlalchemy import inspect
22 from mediagoblin.tools.transition import DISABLE_GLOBALS
if not DISABLE_GLOBALS:
    from sqlalchemy.orm import scoped_session, sessionmaker

    # Module-level scoped session; it only exists when globals are enabled,
    # so every use of ``Session`` must be guarded by the same flag.
    Session = scoped_session(sessionmaker())
class FakeCursor(object):
    """Wrap a cursor-like object so its rows are mapped, then filtered.

    ``cursor`` is the wrapped object; this class relies on it supporting
    ``count()``, ``slice()``, iteration and indexing.  ``mapper`` is called
    on every row that is handed out.  ``filter`` is used as the predicate of
    ``six.moves.filter`` when iterating or slicing.
    """

    def __init__(self, cursor, mapper, filter=None):
        self.cursor, self.mapper, self.filter = cursor, mapper, filter

    def _mapped_and_filtered(self, rows):
        """Lazily apply the mapper to ``rows`` and then the filter."""
        mapped_rows = six.moves.map(self.mapper, rows)
        return six.moves.filter(self.filter, mapped_rows)

    def count(self):
        """Return ``count()`` of the wrapped cursor (no filtering applied)."""
        return self.cursor.count()

    def __copy__(self):
        # Hook picked up by copy.copy(): the wrapped cursor is shallow-copied
        # while the mapper and filter callables are shared with the original.
        duplicated_cursor = copy.copy(self.cursor)
        return FakeCursor(duplicated_cursor, self.mapper, self.filter)

    def __iter__(self):
        """Iterate over the mapped rows that pass the filter."""
        return self._mapped_and_filtered(self.cursor)

    def __getitem__(self, key):
        """Return the mapped item at ``key``; the filter is not consulted."""
        raw_item = self.cursor[key]
        return self.mapper(raw_item)

    def slice(self, *args, **kwargs):
        """Slice the wrapped cursor and return the mapped, filtered rows
        as a list.  All arguments are forwarded to ``cursor.slice``."""
        sliced_rows = self.cursor.slice(*args, **kwargs)
        return list(self._mapped_and_filtered(sliced_rows))
class GMGTableBase(object):
    """Common behaviour shared by the declarative models.

    Provides dict-like read helpers (``get``/``setdefault``), ``save`` and
    the hard/soft deletion machinery.  The class is passed as ``cls`` to
    ``declarative_base`` further down in this module.
    """
    # Deletion types
    HARD_DELETE = "hard-deletion"
    SOFT_DELETE = "soft-deletion"

    # Mode used by delete() when the caller does not pass ``deletion``.
    deletion_mode = HARD_DELETE

    @property
    def _session(self):
        """Session this instance is attached to (callers below treat
        ``None`` as "detached")."""
        return inspect(self).session

    @property
    def _app(self):
        """The ``app`` attribute found on the bind of this instance's
        session."""
        return self._session.bind.app

    if not DISABLE_GLOBALS:
        query = Session.query_property()

    def get(self, key):
        """Dict-style read: return the attribute named ``key``."""
        return getattr(self, key)

    def setdefault(self, key, defaultvalue):
        """Return the attribute named ``key``.

        ``defaultvalue`` is accepted for dict-API compatibility but is never
        used or stored.
        """
        # The key *has* to exist on sql.
        return getattr(self, key)

    def save(self, commit=True):
        """Add this object to its session and commit, or only flush.

        A detached object falls back to the global ``Session`` unless globals
        are disabled.

        :param commit: commit the session when true, otherwise just flush.
        :raises AssertionError: if no session could be found.
        """
        sess = self._session
        if sess is None and not DISABLE_GLOBALS:
            sess = Session()
        assert sess is not None, "Can't save, %r has a detached session" % self
        sess.add(self)
        if commit:
            sess.commit()
        else:
            sess.flush()

    def delete(self, commit=True, deletion=None):
        """Delete the object either using soft or hard deletion.

        :param commit: forwarded to the selected deletion method.
        :param deletion: ``HARD_DELETE`` or ``SOFT_DELETE``; ``None`` selects
            the class' ``deletion_mode``.
        :raises ValueError: if ``deletion`` is neither of the known modes.
        """
        # Get the setting in the model args if none has been specified.
        if deletion is None:
            deletion = self.deletion_mode

        # Hand off to the correct deletion function.
        if deletion == self.HARD_DELETE:
            return self.hard_delete(commit=commit)
        elif deletion == self.SOFT_DELETE:
            return self.soft_delete(commit=commit)
        else:
            raise ValueError(
                "Invalid deletion mode {mode!r}".format(
                    mode=deletion
                )
            )

    def soft_delete(self, commit=True):
        """Replace this object with a ``Graveyard`` tombstone, then delete it.

        The tombstone copies ``public_id`` (when set) and ``object_type``.
        Generic references pointing at this object are re-pointed at the
        tombstone before the object itself is hard deleted.

        :param commit: forwarded to ``hard_delete``; the tombstone itself is
            saved without committing.
        """
        # Create the graveyard version of this model
        # Importing this here due to cyclic imports
        from mediagoblin.db.models import User, Graveyard, GenericModelReference
        tombstone = Graveyard()
        if getattr(self, "public_id", None) is not None:
            tombstone.public_id = self.public_id

        # This is a special case, we don't want to save any actor if the thing
        # being soft deleted is a User model as this would create circular
        # ForeignKeys
        if not isinstance(self, User):
            tombstone.actor = User.query.filter_by(
                id=self.actor
            ).first()
        tombstone.object_type = self.object_type
        tombstone.save(commit=False)

        # There will be a lot of places where the GenericForeignKey will point
        # to the model, we want to remap those to our tombstone.  The value
        # returned by update() is not needed, so it is not kept.
        GenericModelReference.query.filter_by(
            obj_pk=self.id,
            model_type=self.__tablename__
        ).update({
            "obj_pk": tombstone.id,
            "model_type": tombstone.__tablename__,
        })

        # Now we can go ahead and actually delete the model.
        return self.hard_delete(commit=commit)

    def hard_delete(self, commit=True):
        """Delete the object and commit the change immediately by default.

        :param commit: commit the session after deleting when true.
        :raises AssertionError: if the object is not attached to a session.
        """
        sess = self._session
        assert sess is not None, "Not going to delete detached %r" % self
        sess.delete(self)
        if commit:
            sess.commit()
# Declarative base built on GMGTableBase, so classes deriving from ``Base``
# inherit its helper methods.
Base = declarative_base(cls=GMGTableBase)
class DictReadAttrProxy(object):
    """
    Read-only, dict-flavoured view of an object: ``proxy['key']`` is
    answered with ``obj.key``.  Only item reads are forwarded.
    """
    def __init__(self, proxied_obj):
        self.proxied_obj = proxied_obj

    def __getitem__(self, key):
        """Return attribute ``key`` of the proxied object.

        A missing attribute is reported as ``KeyError`` so the proxy behaves
        like a mapping rather than leaking ``AttributeError``.
        """
        target = self.proxied_obj
        try:
            value = getattr(target, key)
        except AttributeError:
            raise KeyError("%r is not an attribute on %r" % (key, target))
        else:
            return value