From 324855967ef23123abbd84fa4913bcb99fb5922a Mon Sep 17 00:00:00 2001 From: Pixel Date: Fri, 3 Oct 2008 10:55:24 -0700 Subject: Adding dblib.lua --- lib/dblib.lua | 519 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 519 insertions(+) create mode 100644 lib/dblib.lua (limited to 'lib') diff --git a/lib/dblib.lua b/lib/dblib.lua new file mode 100644 index 0000000..6469d9c --- /dev/null +++ b/lib/dblib.lua @@ -0,0 +1,519 @@ +--[[ + +/* + * Baltisot + * Copyright (C) 1999-2008 Nicolas "Pixel" Noble + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +]]-- + +-- dblib - database object layer +-- Currently built-in with MySQL - TODO: support some abstraction layer with drivers. + +--[[ + +ddl = { + field1 = { type = "varchar", length = 32 }, + field2 = { type = "int", length = 11, decimal = 8, default = 42 }, + field3 = { options = "pri,auto" }, +} + +]]-- + +--local _luadb +_luadb = { + get_options = function(str) + local s = split(str, ",") + local r = {} + for k, v in pairs(s) do + r[v:upper()] = true + end + + return r + end, + + get_canon_type = function(ttype, k) + if ttype == nil then + return "INT" + elseif type(ttype) ~= "string" then + error("Wrong data in ddl - " .. k .. ".type isn't a string.") + elseif ttype:upper() == "VARCHAR" then + return "VARCHAR" + elseif ttype:upper() == "INT" then + return "INT" + elseif ttype:upper() == "BLOB" then + return "BLOB" + elseif ttype:upper() == "DATETIME" then + return "DATETIME" + elseif ttype:upper() == "NUMBER" then + return "NUMBER" + else + error("Unknow data type in ddl - " .. k .. ": " .. ttype:upper()) + end + end, + + generate_fields = function(ddl) + local k, v + local r, r2 = {}, {} + + for k, v in pairs(ddl) do + r[k] = "" + r2[k] = {} + + r[k] = _luadb.get_canon_type(v.type, k) + + if v.length == nil then + -- do nothing + elseif type(v.length) ~= "number" then + error("Wrong data in ddl - " .. k .. ".length isn't a number.") + elseif v.decimal == nil then + r[k] = r[k] .. "(" .. v.length .. ")" + elseif type(v.decimal) ~= "number" then + error("Wrong data in ddl - " .. k .. ".decimal isn't a number.") + else + r[k] = r[k] .. "(" .. v.length .. "," .. v.decimal .. ")" + end + + if v.options ~= nil and type(v.options) ~= "string" then + error("Wrong data in ddl - " .. k .. ".options isn't a string.") + end + + local options = v.options and _luadb.get_options(v.options) or {} + + if options.NULL then + r[k] = r[k] .. " NULL" + else + r[k] = r[k] .. " NOT NULL" + end + + if options.PRI then + r[k] = r[k] .. " PRIMARY KEY" + end + + if options.UNIQ then + table.insert(r2[k], "ADD UNIQUE (`@fieldname@`)") + end + + if options.AUTO then + r[k] = r[k] .. " AUTO_INCREMENT" + end + end + return r, r2 + end, + + opentable = function(db, tablename, ddl) + local fields, alters = _luadb.generate_fields(ddl) + local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`" + + if db:SafeQuery("DESC " .. tname) ~= 0 then + -- table doesn't exist, create it + if db._.conn:ErrNO() == 1146 then + local q = "CREATE TABLE " .. tname .. " (" + local k, v, first + first = true + + for k, v in pairs(fields) do + if not first then + q = q .. ", " + end + q = q .. "`" .. db.sql_escape(k) .. "` " .. v + end + + q = q .. ");" + + if db:SafeQuery(q) ~= 0 then + error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + + for k, v in pairs(alters) do + local _, stmt + for _, stmt in pairs(v) do + q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) + if db:SafeQuery(q) ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + db:SafeQuery("DROP TABLE " .. tname ";") + end + end + end + else + error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error()) + end + else + -- table exists, let's check it. + local dfields = {} + local i + local any_common = false + + for i = 1, db._.conn:NumRows() do + dfields[i] = db._.conn:FetchRow() + if fields[dfields[i].Field] then + any_common = true + end + end + + if not any_common then + -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead. + q = "DROP TABLE " .. tname .. ";" + if db:SafeQuery(q) ~= 0 then + error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + return db:opentable(tablename, ddl) + end + + local d, k, v = {} + + for k, v in pairs(dfields) do + d[v.Field] = v + if not fields[v.Field] then + q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;" + if db:SafeQuery(q) ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + end + end + dfields, d = d, nil + + local deffered_alters = {} + + for k, v in pairs(fields) do + q = nil + if not dfields[k] then + q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";" + else + local identicals, alters = _luadb.compare_fields(ddl, dfields[k]) + if not identicals then + q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";" + local _, a + for _, a in pairs(alters) do + a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)) + if a.pri < 0 then + table.insert(deffered_alters, a) + else + q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" + if db:SafeQuery(q) ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + end + end + end + end + if q then + if db:SafeQuery(q) ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + end + end + + table.sort(deffered_alters, function(a, b) return a.pri < b.pri end) + + for k, v in pairs(deffered_alters) do + q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" + if db:SafeQuery(q) ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + end + end + end + + return { + insert = _luadb.insert, + delete = _luadb.delete, + restrict = _luadb.restrict, + restricter = _luadb.restricter, + _ = { + tablename = tablename, + tname = tname, + db = db, + ddl = ddl, + conditions = "1", + }, + } + end, + + compare_fields = function(ddl, field) + local f = field.Field + local identical = true + local alters = {} + if type(ddl[f]) ~= "table" then + error("Internal issue with the DDL: key " .. f .. " isn't a table.") + end + -- check base type first. + local ddltype = _luadb.get_canon_type(ddl[f].type, f) + local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%.(%d+)%)" + if desctype == nil then + desctype, desclength = field.Type:match "(%w+)%((%d+)%)" + end + if desctype == nil then + error("Error parsing type string from database: " .. f .. ": " .. field.Type) + end + desctype = _luadb.get_canon_type(desctype, f) + + if ddltype ~= desctype then + identical = false + end + + if field.length and field.length ~= desclength then + identical = false + end + + if field.decimal and field.decimal ~= descdecim then + identical = false + end + + local options = ddl[f].options and _luadb.get_options(ddl[f].options) or {} + + if options.PRI and field.Key ~= "PRI" then + -- needs to add a primary key here, with a specific ALTER TABLE statement + identical = false + table.insert(alters, { stmt = "ADD PRIMARY KEY(@fieldname@)", pri = -2 } ) + end + + if not options.PRI and field.Key == "PRI" then + -- needs to drop the primary key here. + identical = false + table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -1 } ) + end + + if options.UNI and field.Key ~= "UNI" then + -- needs to add a unique index + identical = false + table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) + end + + if not options.UNI and field.Key == "UNI" then + -- needs to drop the unique index + identical = false + table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) + end + + if not options.AUTO and field.Extra == "auto_increment" then + identical = false + end + + if options.AUTO and field.Extra ~= "auto_increment" then + identical = false + end + + return identical, alters + end, + + insert = function(t, data) + local k, v, stmt, stmt2 + local first = true + + stmt = "INSERT INTO " .. t._.tname .. " (" + stmt2 = ") VALUES (" + + for k, v in pairs(data) do + if not t._.ddl[k] then + error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) + end + if not first then + stmt = stmt .. ", " + stmt2 = stmt2 .. ", " + else + first = false + end + stmt = stmt .. "`" .. k .. "`" + stmt2 = stmt2 .. "'" .. v .. "'" + end + + if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then + error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) + end + + return t._.db._.conn:InsertId() + end, + + delete = function(t) + local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";" + + if t._.db:SafeQuery(stmt) ~= 0 then + error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) + end + end, + + update = function(t, data) + local stmt = "UPDATE " .. t._.tname .. " SET " + local first = true + local k, v + + for k, v in pairs(data) do + if t._.ddl[k] then + error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) + end + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. "`" .. k .. "`=" + if type(v) == "string" then + stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"' + else + error("Complex UPDATE queries are not supported yet.") + end + end + + stmt = stmt .. " WHERE " .. t._.conditions + + if t._.db:SafeQuery(stmt) ~= 0 then + error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) + end + end, + + select = function(t, fields) + local stmt = "SELECT " + local first = true + local k, v + + for k, v in pairs(fields) do + if t._.ddl[v] then + error("Trying to select a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) + end + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. "`" .. v .. "`" + end + + stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions + + if t._.db:SafeQuery(stmt) ~= 0 then + error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) + end + + return _luadb.nextrow, t + end, + + nextrow = function(t) + return t._.db._.conn:FetchRow() + end, + + _restricts = { + -- All the rendering is done on-the-fly. + -- this is the leaf function; thus this is the only one which ough to get the sql_escape. + dop = function(r, field, expr, op) + if not r._.tbl._.ddl[field] then + error("Trying to restrict on a field which doesn't exist: " .. field .. " - inside of table " .. r._.tbl._.tablename) + end + + if type(expr) == "string" or type(expr) == "number" then + return "`" .. field .. "` " .. op .. ' "' .. t._.db.sql_escape(expr) .. '"' + elseif type(expr) == "table" then + error("Complex queries not handled for now") + else + error("Wrong type for expression in operator: " .. type(expr)) + end + end, + + varops = function(r, op, ...) + local exprs, first, r, k, v = {...}, true + + for k, v in pairs(exprs) do + if not first then + r = r .. " " .. op .. " " + end + r = r .. v + first = false + end + end, + }, + + restricts = { + NOT = function(r, expr) + return "NOT " .. expr + end, + + AND = function(r, ...) return _luadb._restricts.varops(r, "AND", ...) end, + OR = function(r, ...) return _luadb._restricts.varops(r, "OR", ...) end, + + EQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "=") end, + NEQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "!=") end, + GT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">") end, + LT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<") end, + GE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">=") end, + LE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<=") end, + LIKE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "LIKE") end, + IN = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "IN") end, + }, + + restricter = function(t) + local r, k, v = { _ = { tbl = t } } + + for k, v in pairs(_luadb.restricts) do + r[k] = v + end + + return r + end, + + restrict = function(t, conditions) + -- all in all, "conditions" can just be a direct string. Let the programmer do what he wants, after all. + if conditions == nil or conditions == "" then conditions = "1" end + t._.conditions = conditions + end, +} + +luadb = { + SafeQuery = function(db, str) + local r = db._.conn:Query(str) + if luadb.debugmode then + print(str) + end + + if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected + db._.conn = db:opendb() + r = db._.conn:Query(str) + end + + return r + end, + + opendb = function(id, user, password, base, prefix) + local db + +-- if user == nil and password == nil and base == nil then + -- get into a registry to fetch "id" +-- else +-- end + if type(id) == "table" and type(id._) == "table" then + user = id._.user + password = id._.password + base = id._.base + prefix = id._.prefix + id = id._.id + end + db = SQLConnection(id, user, password, base) + return { + opentable = _luadb.opentable, + SafeQuery = luadb.SafeQuery, + sql_escape = sql_escape, + ErrNO = function(db) return db._.conn:ErrNO() end, + Error = function(db) return db._.conn:Error() end, + _ = { + conn = db, + id = id, + user = user, + password = password, + base = base, + prefix = sql_escape(prefix or ""), + }, + } + end, +} -- cgit v1.2.3