mod_s2s: Handle authentication of s2sin and s2sout the same way
[prosody.git] / util / sql.lua
blob4406d7ff33380ebd5901cd9fcaf6b2e544fb6397
2 local setmetatable, getmetatable = setmetatable, getmetatable;
3 local ipairs = ipairs;
4 local tostring = tostring;
5 local type = type;
6 local assert, pcall, debug_traceback = assert, pcall, debug.traceback;
7 local xpcall = require "util.xpcall".xpcall;
8 local t_concat = table.concat;
9 local log = require "util.logger".init("sql");
11 local DBI = require "DBI";
12 -- This loads all available drivers while globals are unlocked
13 -- LuaDBI should be fixed to not set globals.
14 DBI.Drivers();
15 local build_url = require "socket.url".build;
17 local _ENV = nil;
18 -- luacheck: std none
-- Type-tag metatables for the schema objects built below. An object's kind
-- is recognised purely by metatable identity (see the is_* predicates).
local column_mt, table_mt, query_mt, index_mt = {}, {}, {}, {};
--local op_mt = {};

--- True if x was tagged by Column().
local function is_column(x)
	return column_mt == getmetatable(x);
end
--- True if x was tagged by Index().
local function is_index(x)
	return index_mt == getmetatable(x);
end
--- True if x was built by Table().
local function is_table(x)
	return table_mt == getmetatable(x);
end
--- True if x carries the query metatable.
local function is_query(x)
	return query_mt == getmetatable(x);
end
-- Placeholder type constructors: each just returns its own call syntax
-- as a string.
local function Integer()
	return "Integer()";
end
local function String()
	return "String()";
end
--- Tag a column definition as a Column, in place.
-- Fields read elsewhere in this module: name, type, nullable, primary_key,
-- auto_increment.
-- @tparam table definition the column description (mutated: metatable set)
-- @return the same table
local function Column(definition)
	local tagged = setmetatable(definition, column_mt);
	return tagged;
end
--- Build a Table object from a definition list.
-- Column entries are collected into `c`, both by their position among the
-- columns and by column name. Index entries get `table` set to the table
-- name, which _create_index later reads as index.table.
-- @tparam table definition `name` plus a list of Column/Index objects
-- @return object with fields __table__ (the definition), c and name
local function Table(definition)
	local c = {};
	local n = 0;
	for _, col in ipairs(definition) do
		if is_column(col) then
			-- Number columns consecutively so `c` stays a proper sequence
			-- (for ipairs / # in _create_table) even when Index entries are
			-- interleaved with the columns in the definition.
			n = n + 1;
			c[n], c[col.name] = col, col;
		elseif is_index(col) then
			col.table = definition.name;
		end
	end
	return setmetatable({ __table__ = definition, c = c, name = definition.name }, table_mt);
end
--- Tag an index definition as an Index, in place.
-- _create_index reads name, unique, table and the array items (column names).
-- @tparam table definition the index description (mutated: metatable set)
-- @return the same table
local function Index(definition)
	local tagged = setmetatable(definition, index_mt);
	return tagged;
end
--- Render a Table as `Table{ name="...", <entry>, <entry>, ... }`,
-- where each entry is tostring() of a definition item (column or index).
function table_mt:__tostring()
	local definition = self.__table__;
	local parts = {};
	parts[1] = 'name="'..definition.name..'"';
	for i, item in ipairs(definition) do
		parts[i + 1] = tostring(item);
	end
	return 'Table{ '..t_concat(parts, ", ")..' }';
end
-- Methods available on every Table object.
do
	local methods = {};
	table_mt.__index = methods;

	--- Create this table (and its indexes) through the given engine.
	function methods:create(engine)
		return engine:_create_table(self);
	end
end
--- Render a Column as `Column{ name="...", type="..." }`.
function column_mt:__tostring()
	local name, col_type = self.name, self.type;
	return 'Column{ name="'..name..'", type="'..col_type..'" }';
end
--- Render an Index as `Index{ name="...", "col1", "col2" }`.
-- Backslashes and double quotes inside column names are backslash-escaped.
function index_mt:__tostring()
	local parts = { 'Index{ name="'..self.name..'"' };
	for i = 1, #self do
		local escaped = self[i]:gsub("[\\\"]", "\\%1");
		parts[#parts + 1] = ', "'..escaped..'"';
	end
	parts[#parts + 1] = ' }';
	return t_concat(parts);
	-- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
end
-- Prototype holding the methods shared by every engine instance
-- (instances reach it through engine_mt.__index).
local engine = {}
--- Open the database connection unless one is already open.
-- Connects through LuaDBI using self.params, disables autocommit, resets the
-- prepared-statement cache, configures the connection encoding and then runs
-- the onconnect hook.
-- @return true on success; otherwise a falsy value and an error message
function engine:connect()
	if self.conn then return true; end

	local params = self.params;
	assert(params.driver, "no driver")
	log("debug", "Connecting to [%s] %s...", params.driver, params.database);
	local ok, dbh, err = pcall(DBI.Connect,
		params.driver, params.database,
		params.username, params.password,
		params.host, params.port
	);
	if not ok then return ok, dbh; end
	if not dbh then return nil, err; end
	dbh:autocommit(false); -- don't commit automatically
	self.conn = dbh;
	self.prepared = {};
	local enc_ok, enc_err = self:set_encoding();
	if not enc_ok then
		-- Drop the half-initialised connection. Otherwise the early
		-- "if self.conn" return above would report success on the next call
		-- and the connection would be used without its encoding configured.
		-- (Only clear it if set_encoding's retry logic did not already
		-- replace it with a fresh connection.)
		if self.conn == dbh then
			self.conn = nil;
		end
		return enc_ok, enc_err;
	end
	local hook_ok, hook_err = self:onconnect();
	if hook_ok == false then
		return hook_ok, hook_err;
	end
	return true;
end
--- Default connection hook; does nothing.
-- create_engine() stores a per-instance onconnect that shadows this method.
-- connect() treats a return value of exactly false as a failure.
function engine.onconnect(self) -- luacheck: ignore 212/self
	-- Override from create_engine()
	return;
end
--- Adapt SQL text to the configured driver.
-- Identifiers in this module are written in double quotes. For MySQL every
-- double quote, including any inside string literals, is swapped for a
-- backtick; other drivers get the text unchanged.
-- @tparam string sql
-- @treturn string
function engine:prepquery(sql)
	if self.params.driver ~= "MySQL" then
		return sql;
	end
	local rewritten = sql:gsub("\"", "`");
	return rewritten;
end
--- Execute a statement, connecting first if necessary.
-- Prepared statements are cached in self.prepared, keyed by the
-- driver-adapted SQL text.
-- @return the executed statement handle; otherwise a falsy value and an error
function engine:execute(sql, ...)
	local connected, connect_err = self:connect();
	if not connected then return connected, connect_err; end
	local cache = self.prepared;

	sql = self:prepquery(sql);
	local stmt = cache[sql];
	if not stmt then
		local prepare_err;
		stmt, prepare_err = self.conn:prepare(sql);
		if not stmt then return stmt, prepare_err; end
		cache[sql] = stmt;
	end

	local executed, exec_err = stmt:execute(...);
	if not executed then return executed, exec_err; end
	return stmt;
end
-- Metatable for the result objects returned by execute_update(). Each result
-- keeps its statement in __stmt and forwards these calls to it.
local result_mt = { __index = {} };
function result_mt.__index:affected()
	return self.__stmt:affected();
end
function result_mt.__index:rowcount()
	return self.__stmt:rowcount();
end
--- Log an SQL statement at debug level with its parameters inlined.
-- Runs of tabs (optionally preceded by one newline) collapse to a single
-- space. Each "?" is replaced by the next argument: strings are
-- single-quoted with embedded quotes doubled, other values go through
-- tostring(). The result is only passed to log() here.
-- @tparam string where label for the log line
-- @tparam string sql statement text with ? placeholders
local function debugquery(where, sql, ...)
	local args = { ... };
	local n = 0;
	local function substitute()
		n = n + 1;
		local value = args[n];
		if type(value) == "string" then
			return ("'%s'"):format((value:gsub("'", "''")));
		end
		return tostring(value);
	end
	local flattened = sql:gsub("\n?\t+", " ");
	local rendered = flattened:gsub("%?", substitute);
	log("debug", "[%s] %s", where, rendered);
end
--- Run a query and return an iterator over all of its rows.
-- The statement is prepared afresh (not cached), drained into a table and
-- closed before this returns. Failures are raised (via assert), which
-- _transaction's xpcall turns into a rollback.
-- Requires an open connection (self.conn).
-- @return iterator function yielding each row in turn, then nil
function engine:execute_query(sql, ...)
	sql = self:prepquery(sql);
	local stmt = assert(self.conn:prepare(sql));
	local ok, err = stmt:execute(...);
	if not ok then
		-- Release the statement now rather than leaving it for the GC.
		stmt:close();
		assert(ok, err);
	end
	local result = {};
	for row in stmt:rows() do result[#result + 1] = row; end
	stmt:close();
	local i = 0;
	return function() i = i + 1; return result[i]; end;
end
--- Execute a data-modifying statement, reusing a cached prepared statement.
-- Raises (via assert) if preparing or executing fails. Requires an open
-- connection (self.conn) and cache (self.prepared).
-- @return result object exposing affected() and rowcount()
function engine:execute_update(sql, ...)
	sql = self:prepquery(sql);
	local cache = self.prepared;
	local stmt = cache[sql];
	if not stmt then
		stmt = assert(self.conn:prepare(sql));
		cache[sql] = stmt;
	end
	assert(stmt:execute(...));
	local result = { __stmt = stmt };
	return setmetatable(result, result_mt);
end
-- Public query entry points; by default they are the plain executors.
-- engine:debug() can replace them with logging wrappers.
engine.select = engine.execute_query;
engine.insert, engine.delete, engine.update =
	engine.execute_update, engine.execute_update, engine.execute_update;
--- Wrap a query executor so every call is logged via debugquery first.
-- @tparam string name label shown in the log line
-- @tparam function f executor called as f(self, sql, ...)
-- @treturn function wrapper with the same arguments and results as f
local function debugwrap(name, f)
	local function logged(self, sql, ...)
		debugquery(name, sql, ...);
		return f(self, sql, ...);
	end
	return logged;
end
--- Enable or disable logging of every query issued through this engine.
-- The flag is stored in self._debug, which _create_table and _create_index
-- consult. The insert/select/delete/update wrappers are installed on this
-- instance only, so toggling debug on one engine no longer changes the
-- behaviour of every other engine sharing the `engine` prototype.
-- @param enable truthy to enable query logging
function engine:debug(enable)
	self._debug = enable;
	if enable then
		self.insert = debugwrap("insert", engine.execute_update);
		self.select = debugwrap("select", engine.execute_query);
		self.delete = debugwrap("delete", engine.execute_update);
		self.update = debugwrap("update", engine.execute_update);
	else
		self.insert = engine.execute_update;
		self.select = engine.execute_query;
		self.delete = engine.execute_update;
		self.update = engine.execute_update;
	end
end
--- xpcall message handler used by _transaction.
-- Builds a traceback (starting at stack level 3), logs it at debug level and
-- returns a table carrying both the original error and that traceback.
local function handleerr(err)
	local traceback = debug_traceback(err, 3);
	log("debug", "Error in SQL transaction: %s", traceback);
	local info = { err = err };
	info.traceback = traceback;
	return info;
end
--- Run func(...) inside a single transaction attempt.
-- Connects first if there is no connection. On success commits and returns
-- true plus up to three values from func. On error rolls back (if still
-- connected) and returns false plus the original error.
function engine:_transaction(func, ...)
	if not self.conn then
		local connected, connect_err = self:connect();
		if not connected then return connected, connect_err; end
	end
	--assert(not self.__transaction, "Recursive transactions not allowed");
	log("debug", "SQL transaction begin [%s]", func);
	self.__transaction = true;
	local success, r1, r2, r3 = xpcall(func, handleerr, ...);
	self.__transaction = nil;

	if not success then
		-- r1 is the table produced by handleerr
		log("debug", "SQL transaction failure [%s]: %s", func, r1.err);
		local conn = self.conn;
		if conn then conn:rollback(); end
		return success, r1.err;
	end

	log("debug", "SQL transaction success [%s]", func);
	-- LuaDBI doesn't actually return an error message here, just a boolean
	local committed, commit_err = self.conn:commit();
	if not committed then
		return committed, commit_err or "commit failed";
	end
	return success, r1, r2, r3;
end
--- Run a transaction, retrying once if the connection appears to be gone.
-- After a failure the connection is checked with ping(). If it is missing or
-- dead, self.conn is cleared so the retry reconnects. A failure that
-- remains is logged at error level.
-- @return ok, ret (only the first value produced by the transaction)
function engine:transaction(...)
	local ok, ret = self:_transaction(...);
	if ok then
		return ok, ret;
	end

	local conn = self.conn;
	if conn and conn:ping() then
		log("debug", "SQL connection is up, so not retrying");
	else
		log("debug", "Database connection was closed. Will reconnect and retry.");
		self.conn = nil;
		log("debug", "Retrying SQL transaction [%s]", (...));
		ok, ret = self:_transaction(...);
		log("debug", "SQL transaction retry %s", ok and "succeeded" or "failed");
	end

	if not ok then
		log("error", "Error in SQL transaction: %s", ret);
	end
	return ok, ret;
end
--- Create an index on the table named by index.table.
-- Identifiers are double-quoted; execute() -> prepquery turns them into
-- backticks on MySQL. Non-MySQL drivers get IF NOT EXISTS. MySQL instead
-- gets a (20) prefix length on every indexed column. index.unique adds
-- UNIQUE.
-- @param index an Index object (uses name, table, unique and its array items)
-- @return whatever execute() returns
function engine:_create_index(index)
	local is_mysql = self.params.driver == "MySQL";
	local quoted = {};
	for i = 1, #index do
		quoted[i] = "\""..index[i].."\"";
	end
	local keyword = index.unique and "CREATE UNIQUE INDEX" or "CREATE INDEX";
	if not is_mysql then
		keyword = keyword.." IF NOT EXISTS";
	end
	local sql = keyword.." \""..index.name.."\" ON \""..index.table.."\" ("..t_concat(quoted, ", ")..");";
	if is_mysql then
		sql = sql:gsub("\"([,)])", "\"(20)%1");
	end
	if self._debug then
		debugquery("create", sql);
	end
	return self:execute(sql);
end
--- Create a table if it does not already exist, then its indexes.
-- Column types are adapted per driver: MEDIUMTEXT becomes TEXT outside
-- MySQL. Auto-increment columns become BIGSERIAL on PostgreSQL, get
-- AUTO_INCREMENT on MySQL and get AUTOINCREMENT on SQLite3. On MySQL the
-- table uses self.charset with its _bin collation.
-- Results of the individual index creations are not checked.
-- @param table a Table object (uses name, c and __table__)
-- @return the statement from execute() on success; otherwise a falsy value and an error
function engine:_create_table(table)
	local driver = self.params.driver;
	local columns = table.c;
	local last = #columns;
	local pieces = {};
	for i, col in ipairs(columns) do
		local col_type = col.type;
		if col_type == "MEDIUMTEXT" and driver ~= "MySQL" then
			col_type = "TEXT"; -- MEDIUMTEXT is MySQL-specific
		end
		local auto_increment = col.auto_increment == true;
		if auto_increment and driver == "PostgreSQL" then
			col_type = "BIGSERIAL";
		end
		local def = "\""..col.name.."\" "..col_type;
		if col.nullable == false then def = def.." NOT NULL"; end
		if col.primary_key == true then def = def.." PRIMARY KEY"; end
		if auto_increment then
			if driver == "MySQL" then
				def = def.." AUTO_INCREMENT";
			elseif driver == "SQLite3" then
				def = def.." AUTOINCREMENT";
			end
		end
		pieces[#pieces + 1] = def;
		if i ~= last then pieces[#pieces + 1] = ", "; end
	end
	local sql = "CREATE TABLE IF NOT EXISTS \""..table.name.."\" ("..t_concat(pieces)..");";
	if driver == "MySQL" then
		sql = sql:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self.charset, self.charset));
	end
	if self._debug then
		debugquery("create", sql);
	end

	local created, err = self:execute(sql);
	if not created then return created, err; end
	for _, item in ipairs(table.__table__) do
		if is_index(item) then
			self:_create_index(item);
		end
	end
	return created;
end
--- Configure the connection's character set, aiming for UTF-8.
-- SQLite3: sets self.charset = "utf8" only if PRAGMA encoding reports UTF-8.
-- Other drivers: default to "utf8". MySQL instead picks the utf8* charset
-- with the largest MAXLEN and adds a matching _bin collation. Then SET NAMES
-- is issued. MySQL finally verifies the session's character_set_client.
-- @return true on success; otherwise a falsy value and an error message
function engine:set_encoding() -- to UTF-8
	local driver = self.params.driver;
	if driver == "SQLite3" then
		return self:transaction(function()
			for encoding in self:select"PRAGMA encoding;" do
				if encoding[1] == "UTF-8" then
					self.charset = "utf8";
				end
			end
		end);
	end
	local set_names_query = "SET NAMES '%s';"
	local charset = "utf8";
	if driver == "MySQL" then
		-- Best effort: if this lookup fails, "utf8" is kept.
		self:transaction(function()
			for row in self:select[[
				SELECT "CHARACTER_SET_NAME"
				FROM "information_schema"."CHARACTER_SETS"
				WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
				ORDER BY MAXLEN DESC LIMIT 1;
				]] do
				charset = row and row[1] or charset;
			end
		end);
		set_names_query = set_names_query:gsub(";$", (" COLLATE '%s';"):format(charset.."_bin"));
	end
	self.charset = charset;
	log("debug", "Using encoding '%s' for database connection", charset);
	local ok, err = self:transaction(function() return self:execute(set_names_query:format(charset)); end);
	if not ok then
		return ok, err;
	end

	if driver == "MySQL" then
		local check_ok, actual_charset = self:transaction(function ()
			return self:select"SHOW SESSION VARIABLES LIKE 'character_set_client'";
		end);
		if not check_ok then
			-- actual_charset holds the error here, not a row iterator;
			-- iterating over it would raise "attempt to call".
			return check_ok, actual_charset;
		end
		local charset_ok = true;
		for row in actual_charset do
			if row[2] ~= charset then
				log("error", "MySQL %s is actually %q (expected %q)", row[1], row[2], charset);
				charset_ok = false;
			end
		end
		if not charset_ok then
			return false, "Failed to set connection encoding";
		end
	end

	return true;
end
-- Engine instances (see create_engine) inherit the prototype's methods.
local engine_mt = {};
engine_mt.__index = engine;
--- Build a URL-style identifier for a set of connection parameters.
-- Maps driver->scheme, username->user, password, host, port and
-- database->path, then passes them to socket.url's build().
-- @tparam table params connection parameters
-- @return the value returned by build_url
local function db2uri(params)
	local parts = {
		scheme = params.driver,
		user = params.username,
		password = params.password,
		host = params.host,
		port = params.port,
		path = params.database,
	};
	return build_url(parts);
end
--- Create an engine for the given connection parameters (not yet connected).
-- The first argument is ignored (presumably the module table when called as
-- sql:create_engine(...) — confirm against callers).
-- @tparam table params driver, database, username, password, host, port
-- @tparam[opt] function onconnect hook that connect() runs after setting the encoding
-- @return engine instance inheriting from the engine prototype
local function create_engine(_, params, onconnect)
	local instance = {
		url = db2uri(params),
		params = params,
		onconnect = onconnect,
	};
	return setmetatable(instance, engine_mt);
end
369 return {
370 is_column = is_column;
371 is_index = is_index;
372 is_table = is_table;
373 is_query = is_query;
374 Integer = Integer;
375 String = String;
376 Column = Column;
377 Table = Table;
378 Index = Index;
379 create_engine = create_engine;
380 db2uri = db2uri;