2 local setmetatable
, getmetatable
= setmetatable
, getmetatable
;
4 local tostring = tostring;
6 local assert, pcall
, debug_traceback
= assert, pcall
, debug
.traceback
;
7 local xpcall
= require
"util.xpcall".xpcall
;
8 local t_concat
= table.concat
;
9 local log = require
"util.logger".init("sql");
11 local DBI
= require
"DBI";
12 -- This loads all available drivers while globals are unlocked
13 -- LuaDBI should be fixed to not set globals.
15 local build_url
= require
"socket.url".build
;
26 local function is_column(x
) return getmetatable(x
)==column_mt
; end
27 local function is_index(x
) return getmetatable(x
)==index_mt
; end
28 local function is_table(x
) return getmetatable(x
)==table_mt
; end
29 local function is_query(x
) return getmetatable(x
)==query_mt
; end
30 local function Integer() return "Integer()" end
31 local function String() return "String()" end
33 local function Column(definition
)
34 return setmetatable(definition
, column_mt
);
36 local function Table(definition
)
38 for i
,col
in ipairs(definition
) do
39 if is_column(col
) then
40 c
[i
], c
[col
.name
] = col
, col
;
41 elseif is_index(col
) then
42 col
.table = definition
.name
;
45 return setmetatable({ __table__
= definition
, c
= c
, name
= definition
.name
}, table_mt
);
47 local function Index(definition
)
48 return setmetatable(definition
, index_mt
);
51 function table_mt
:__tostring()
52 local s
= { 'name="'..self
.__table__
.name
..'"' }
53 for _
, col
in ipairs(self
.__table__
) do
54 s
[#s
+1] = tostring(col
);
56 return 'Table{ '..t_concat(s
, ", ")..' }'
58 table_mt
.__index
= {};
59 function table_mt
.__index
:create(engine
)
60 return engine
:_create_table(self
);
62 function column_mt
:__tostring()
63 return 'Column{ name="'..self
.name
..'", type="'..self
.type..'" }'
65 function index_mt
:__tostring()
66 local s
= 'Index{ name="'..self
.name
..'"';
67 for i
=1,#self
do s
= s
..', "'..self
[i
]:gsub("[\\\"]", "\\%1")..'"'; end
69 -- return 'Index{ name="'..self.name..'", type="'..self.type..'" }'
73 function engine
:connect()
74 if self
.conn
then return true; end
76 local params
= self
.params
;
77 assert(params
.driver
, "no driver")
78 log("debug", "Connecting to [%s] %s...", params
.driver
, params
.database
);
79 local ok
, dbh
, err
= pcall(DBI
.Connect
,
80 params
.driver
, params
.database
,
81 params
.username
, params
.password
,
82 params
.host
, params
.port
84 if not ok
then return ok
, dbh
; end
85 if not dbh
then return nil, err
; end
86 dbh
:autocommit(false); -- don't commit automatically
89 local ok
, err
= self
:set_encoding();
93 local ok
, err
= self
:onconnect();
99 function engine
:onconnect() -- luacheck: ignore 212/self
100 -- Override from create_engine()
103 function engine
:prepquery(sql
)
104 if self
.params
.driver
== "MySQL" then
105 sql
= sql
:gsub("\"", "`");
110 function engine
:execute(sql
, ...)
111 local success
, err
= self
:connect();
112 if not success
then return success
, err
; end
113 local prepared
= self
.prepared
;
115 sql
= self
:prepquery(sql
);
116 local stmt
= prepared
[sql
];
119 stmt
, err
= self
.conn
:prepare(sql
);
120 if not stmt
then return stmt
, err
; end
121 prepared
[sql
] = stmt
;
124 -- luacheck: ignore 411/success
125 local success
, err
= stmt
:execute(...);
126 if not success
then return success
, err
; end
130 local result_mt
= { __index
= {
131 affected
= function(self
) return self
.__stmt
:affected(); end;
132 rowcount
= function(self
) return self
.__stmt
:rowcount(); end;
135 local function debugquery(where
, sql
, ...)
136 local i
= 0; local a
= {...}
137 sql
= sql
:gsub("\n?\t+", " ");
138 log("debug", "[%s] %s", where
, (sql
:gsub("%?", function ()
141 if type(v
) == "string" then
142 v
= ("'%s'"):format(v
:gsub("'", "''"));
148 function engine
:execute_query(sql
, ...)
149 sql
= self
:prepquery(sql
);
150 local stmt
= assert(self
.conn
:prepare(sql
));
151 assert(stmt
:execute(...));
153 for row
in stmt
:rows() do result
[#result
+ 1] = row
; end
156 return function() i
=i
+1; return result
[i
]; end;
158 function engine
:execute_update(sql
, ...)
159 sql
= self
:prepquery(sql
);
160 local prepared
= self
.prepared
;
161 local stmt
= prepared
[sql
];
163 stmt
= assert(self
.conn
:prepare(sql
));
164 prepared
[sql
] = stmt
;
166 assert(stmt
:execute(...));
167 return setmetatable({ __stmt
= stmt
}, result_mt
);
169 engine
.insert
= engine
.execute_update
;
170 engine
.select
= engine
.execute_query
;
171 engine
.delete
= engine
.execute_update
;
172 engine
.update
= engine
.execute_update
;
173 local function debugwrap(name
, f
)
174 return function (self
, sql
, ...)
175 debugquery(name
, sql
, ...)
176 return f(self
, sql
, ...)
179 function engine
:debug(enable
)
180 self
._debug
= enable
;
182 engine
.insert
= debugwrap("insert", engine
.execute_update
);
183 engine
.select
= debugwrap("select", engine
.execute_query
);
184 engine
.delete
= debugwrap("delete", engine
.execute_update
);
185 engine
.update
= debugwrap("update", engine
.execute_update
);
187 engine
.insert
= engine
.execute_update
;
188 engine
.select
= engine
.execute_query
;
189 engine
.delete
= engine
.execute_update
;
190 engine
.update
= engine
.execute_update
;
193 local function handleerr(err
)
194 local trace
= debug_traceback(err
, 3);
195 log("debug", "Error in SQL transaction: %s", trace
);
196 return { err
= err
, traceback
= trace
};
198 function engine
:_transaction(func
, ...)
199 if not self
.conn
then
200 local ok
, err
= self
:connect();
201 if not ok
then return ok
, err
; end
203 --assert(not self.__transaction, "Recursive transactions not allowed");
204 log("debug", "SQL transaction begin [%s]", func
);
205 self
.__transaction
= true;
206 local success
, a
, b
, c
= xpcall(func
, handleerr
, ...);
207 self
.__transaction
= nil;
209 log("debug", "SQL transaction success [%s]", func
);
210 local ok
, err
= self
.conn
:commit();
211 -- LuaDBI doesn't actually return an error message here, just a boolean
212 if not ok
then return ok
, err
or "commit failed"; end
213 return success
, a
, b
, c
;
215 log("debug", "SQL transaction failure [%s]: %s", func
, a
.err
);
216 if self
.conn
then self
.conn
:rollback(); end
217 return success
, a
.err
;
220 function engine
:transaction(...)
221 local ok
, ret
= self
:_transaction(...);
223 local conn
= self
.conn
;
224 if not conn
or not conn
:ping() then
225 log("debug", "Database connection was closed. Will reconnect and retry.");
227 log("debug", "Retrying SQL transaction [%s]", (...));
228 ok
, ret
= self
:_transaction(...);
229 log("debug", "SQL transaction retry %s", ok
and "succeeded" or "failed");
231 log("debug", "SQL connection is up, so not retrying");
234 log("error", "Error in SQL transaction: %s", ret
);
239 function engine
:_create_index(index
)
240 local sql
= "CREATE INDEX \""..index
.name
.."\" ON \""..index
.table.."\" (";
241 if self
.params
.driver
~= "MySQL" then
242 sql
= sql
:gsub("^CREATE INDEX", "%1 IF NOT EXISTS");
245 sql
= sql
.."\""..index
[i
].."\"";
246 if i
~= #index
then sql
= sql
..", "; end
249 if self
.params
.driver
== "MySQL" then
250 sql
= sql
:gsub("\"([,)])", "\"(20)%1");
253 sql
= sql
:gsub("^CREATE", "CREATE UNIQUE");
256 debugquery("create", sql
);
258 return self
:execute(sql
);
260 function engine
:_create_table(table)
261 local sql
= "CREATE TABLE \""..table.name
.."\" (";
263 sql
= sql
:gsub("^CREATE TABLE", "%1 IF NOT EXISTS");
265 for i
,col
in ipairs(table.c
) do
266 local col_type
= col
.type;
267 if col_type
== "MEDIUMTEXT" and self
.params
.driver
~= "MySQL" then
268 col_type
= "TEXT"; -- MEDIUMTEXT is MySQL-specific
270 if col
.auto_increment
== true and self
.params
.driver
== "PostgreSQL" then
271 col_type
= "BIGSERIAL";
273 sql
= sql
.."\""..col
.name
.."\" "..col_type
;
274 if col
.nullable
== false then sql
= sql
.." NOT NULL"; end
275 if col
.primary_key
== true then sql
= sql
.." PRIMARY KEY"; end
276 if col
.auto_increment
== true then
277 if self
.params
.driver
== "MySQL" then
278 sql
= sql
.." AUTO_INCREMENT";
279 elseif self
.params
.driver
== "SQLite3" then
280 sql
= sql
.." AUTOINCREMENT";
283 if i
~= #table.c
then sql
= sql
..", "; end
286 if self
.params
.driver
== "MySQL" then
287 sql
= sql
:gsub(";$", (" CHARACTER SET '%s' COLLATE '%s_bin';"):format(self
.charset
, self
.charset
));
290 debugquery("create", sql
);
292 local success
,err
= self
:execute(sql
);
293 if not success
then return success
,err
; end
294 for _
, v
in ipairs(table.__table__
) do
296 self
:_create_index(v
);
301 function engine
:set_encoding() -- to UTF-8
302 local driver
= self
.params
.driver
;
303 if driver
== "SQLite3" then
304 return self
:transaction(function()
305 for encoding
in self
:select
"PRAGMA encoding;" do
306 if encoding
[1] == "UTF-8" then
307 self
.charset
= "utf8";
312 local set_names_query
= "SET NAMES '%s';"
313 local charset
= "utf8";
314 if driver
== "MySQL" then
315 self
:transaction(function()
316 for row
in self
:select
[[
317 SELECT "CHARACTER_SET_NAME"
318 FROM "information_schema"."CHARACTER_SETS"
319 WHERE "CHARACTER_SET_NAME" LIKE 'utf8%'
320 ORDER BY MAXLEN DESC LIMIT 1;
322 charset
= row
and row
[1] or charset
;
325 set_names_query
= set_names_query
:gsub(";$", (" COLLATE '%s';"):format(charset
.."_bin"));
327 self
.charset
= charset
;
328 log("debug", "Using encoding '%s' for database connection", charset
);
329 local ok
, err
= self
:transaction(function() return self
:execute(set_names_query
:format(charset
)); end);
334 if driver
== "MySQL" then
335 local ok
, actual_charset
= self
:transaction(function ()
336 return self
:select
"SHOW SESSION VARIABLES LIKE 'character_set_client'";
338 local charset_ok
= true;
339 for row
in actual_charset
do
340 if row
[2] ~= charset
then
341 log("error", "MySQL %s is actually %q (expected %q)", row
[1], row
[2], charset
);
345 if not charset_ok
then
346 return false, "Failed to set connection encoding";
352 local engine_mt
= { __index
= engine
};
354 local function db2uri(params
)
356 scheme
= params
.driver
,
357 user
= params
.username
,
358 password
= params
.password
,
361 path
= params
.database
,
365 local function create_engine(_
, params
, onconnect
)
366 return setmetatable({ url
= db2uri(params
), params
= params
, onconnect
= onconnect
}, engine_mt
);
370 is_column
= is_column
;
379 create_engine
= create_engine
;