summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/dblib.lua519
1 files changed, 519 insertions, 0 deletions
diff --git a/lib/dblib.lua b/lib/dblib.lua
new file mode 100644
index 0000000..6469d9c
--- /dev/null
+++ b/lib/dblib.lua
@@ -0,0 +1,519 @@
+--[[
+
+/*
+ * Baltisot
+ * Copyright (C) 1999-2008 Nicolas "Pixel" Noble
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation; either version 2 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program; if not, write to the Free Software
+ * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
+ */
+
+]]--
+
+-- dblib - database object layer
+-- Currently built-in with MySQL - TODO: support some abstraction layer with drivers.
+
+--[[
+
+ddl = {
+ field1 = { type = "varchar", length = 32 },
+ field2 = { type = "int", length = 11, decimal = 8, default = 42 },
+ field3 = { options = "pri,auto" },
+}
+
+]]--
+
+--local _luadb
+_luadb = {
+ get_options = function(str)
+ local s = split(str, ",")
+ local r = {}
+ for k, v in pairs(s) do
+ r[v:upper()] = true
+ end
+
+ return r
+ end,
+
+ get_canon_type = function(ttype, k)
+ if ttype == nil then
+ return "INT"
+ elseif type(ttype) ~= "string" then
+ error("Wrong data in ddl - " .. k .. ".type isn't a string.")
+ elseif ttype:upper() == "VARCHAR" then
+ return "VARCHAR"
+ elseif ttype:upper() == "INT" then
+ return "INT"
+ elseif ttype:upper() == "BLOB" then
+ return "BLOB"
+ elseif ttype:upper() == "DATETIME" then
+ return "DATETIME"
+ elseif ttype:upper() == "NUMBER" then
+ return "NUMBER"
+ else
+ error("Unknow data type in ddl - " .. k .. ": " .. ttype:upper())
+ end
+ end,
+
+ generate_fields = function(ddl)
+ local k, v
+ local r, r2 = {}, {}
+
+ for k, v in pairs(ddl) do
+ r[k] = ""
+ r2[k] = {}
+
+ r[k] = _luadb.get_canon_type(v.type, k)
+
+ if v.length == nil then
+ -- do nothing
+ elseif type(v.length) ~= "number" then
+ error("Wrong data in ddl - " .. k .. ".length isn't a number.")
+ elseif v.decimal == nil then
+ r[k] = r[k] .. "(" .. v.length .. ")"
+ elseif type(v.decimal) ~= "number" then
+ error("Wrong data in ddl - " .. k .. ".decimal isn't a number.")
+ else
+ r[k] = r[k] .. "(" .. v.length .. "," .. v.decimal .. ")"
+ end
+
+ if v.options ~= nil and type(v.options) ~= "string" then
+ error("Wrong data in ddl - " .. k .. ".options isn't a string.")
+ end
+
+ local options = v.options and _luadb.get_options(v.options) or {}
+
+ if options.NULL then
+ r[k] = r[k] .. " NULL"
+ else
+ r[k] = r[k] .. " NOT NULL"
+ end
+
+ if options.PRI then
+ r[k] = r[k] .. " PRIMARY KEY"
+ end
+
+ if options.UNIQ then
+ table.insert(r2[k], "ADD UNIQUE (`@fieldname@`)")
+ end
+
+ if options.AUTO then
+ r[k] = r[k] .. " AUTO_INCREMENT"
+ end
+ end
+ return r, r2
+ end,
+
+ opentable = function(db, tablename, ddl)
+ local fields, alters = _luadb.generate_fields(ddl)
+ local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`"
+
+ if db:SafeQuery("DESC " .. tname) ~= 0 then
+ -- table doesn't exist, create it
+ if db._.conn:ErrNO() == 1146 then
+ local q = "CREATE TABLE " .. tname .. " ("
+ local k, v, first
+ first = true
+
+ for k, v in pairs(fields) do
+ if not first then
+ q = q .. ", "
+ end
+ q = q .. "`" .. db.sql_escape(k) .. "` " .. v
+ end
+
+ q = q .. ");"
+
+ if db:SafeQuery(q) ~= 0 then
+ error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+
+ for k, v in pairs(alters) do
+ local _, stmt
+ for _, stmt in pairs(v) do
+ q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k))
+ if db:SafeQuery(q) ~= 0 then
+ error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ db:SafeQuery("DROP TABLE " .. tname ";")
+ end
+ end
+ end
+ else
+ error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error())
+ end
+ else
+ -- table exists, let's check it.
+ local dfields = {}
+ local i
+ local any_common = false
+
+ for i = 1, db._.conn:NumRows() do
+ dfields[i] = db._.conn:FetchRow()
+ if fields[dfields[i].Field] then
+ any_common = true
+ end
+ end
+
+ if not any_common then
+ -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead.
+ q = "DROP TABLE " .. tname .. ";"
+ if db:SafeQuery(q) ~= 0 then
+ error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+ return db:opentable(tablename, ddl)
+ end
+
+ local d, k, v = {}
+
+ for k, v in pairs(dfields) do
+ d[v.Field] = v
+ if not fields[v.Field] then
+ q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;"
+ if db:SafeQuery(q) ~= 0 then
+ error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+ end
+ end
+ dfields, d = d, nil
+
+ local deffered_alters = {}
+
+ for k, v in pairs(fields) do
+ q = nil
+ if not dfields[k] then
+ q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";"
+ else
+ local identicals, alters = _luadb.compare_fields(ddl, dfields[k])
+ if not identicals then
+ q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";"
+ local _, a
+ for _, a in pairs(alters) do
+ a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k))
+ if a.pri < 0 then
+ table.insert(deffered_alters, a)
+ else
+ q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";"
+ if db:SafeQuery(q) ~= 0 then
+ error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+ end
+ end
+ end
+ end
+ if q then
+ if db:SafeQuery(q) ~= 0 then
+ error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+ end
+ end
+
+ table.sort(deffered_alters, function(a, b) return a.pri < b.pri end)
+
+ for k, v in pairs(deffered_alters) do
+ q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";"
+ if db:SafeQuery(q) ~= 0 then
+ error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)
+ end
+ end
+ end
+
+ return {
+ insert = _luadb.insert,
+ delete = _luadb.delete,
+ restrict = _luadb.restrict,
+ restricter = _luadb.restricter,
+ _ = {
+ tablename = tablename,
+ tname = tname,
+ db = db,
+ ddl = ddl,
+ conditions = "1",
+ },
+ }
+ end,
+
+ compare_fields = function(ddl, field)
+ local f = field.Field
+ local identical = true
+ local alters = {}
+ if type(ddl[f]) ~= "table" then
+ error("Internal issue with the DDL: key " .. f .. " isn't a table.")
+ end
+ -- check base type first.
+ local ddltype = _luadb.get_canon_type(ddl[f].type, f)
+ local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%.(%d+)%)"
+ if desctype == nil then
+ desctype, desclength = field.Type:match "(%w+)%((%d+)%)"
+ end
+ if desctype == nil then
+ error("Error parsing type string from database: " .. f .. ": " .. field.Type)
+ end
+ desctype = _luadb.get_canon_type(desctype, f)
+
+ if ddltype ~= desctype then
+ identical = false
+ end
+
+ if field.length and field.length ~= desclength then
+ identical = false
+ end
+
+ if field.decimal and field.decimal ~= descdecim then
+ identical = false
+ end
+
+ local options = ddl[f].options and _luadb.get_options(ddl[f].options) or {}
+
+ if options.PRI and field.Key ~= "PRI" then
+ -- needs to add a primary key here, with a specific ALTER TABLE statement
+ identical = false
+ table.insert(alters, { stmt = "ADD PRIMARY KEY(@fieldname@)", pri = -2 } )
+ end
+
+ if not options.PRI and field.Key == "PRI" then
+ -- needs to drop the primary key here.
+ identical = false
+ table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -1 } )
+ end
+
+ if options.UNI and field.Key ~= "UNI" then
+ -- needs to add a unique index
+ identical = false
+ table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } )
+ end
+
+ if not options.UNI and field.Key == "UNI" then
+ -- needs to drop the unique index
+ identical = false
+ table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } )
+ end
+
+ if not options.AUTO and field.Extra == "auto_increment" then
+ identical = false
+ end
+
+ if options.AUTO and field.Extra ~= "auto_increment" then
+ identical = false
+ end
+
+ return identical, alters
+ end,
+
+ insert = function(t, data)
+ local k, v, stmt, stmt2
+ local first = true
+
+ stmt = "INSERT INTO " .. t._.tname .. " ("
+ stmt2 = ") VALUES ("
+
+ for k, v in pairs(data) do
+ if not t._.ddl[k] then
+ error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
+ end
+ if not first then
+ stmt = stmt .. ", "
+ stmt2 = stmt2 .. ", "
+ else
+ first = false
+ end
+ stmt = stmt .. "`" .. k .. "`"
+ stmt2 = stmt2 .. "'" .. v .. "'"
+ end
+
+ if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then
+ error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error())
+ end
+
+ return t._.db._.conn:InsertId()
+ end,
+
+ delete = function(t)
+ local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";"
+
+ if t._.db:SafeQuery(stmt) ~= 0 then
+ error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
+ end
+ end,
+
+ update = function(t, data)
+ local stmt = "UPDATE " .. t._.tname .. " SET "
+ local first = true
+ local k, v
+
+ for k, v in pairs(data) do
+ if t._.ddl[k] then
+ error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
+ end
+ if not first then
+ stmt = stmt .. ", "
+ else
+ first = false
+ end
+ stmt = stmt .. "`" .. k .. "`="
+ if type(v) == "string" then
+ stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"'
+ else
+ error("Complex UPDATE queries are not supported yet.")
+ end
+ end
+
+ stmt = stmt .. " WHERE " .. t._.conditions
+
+ if t._.db:SafeQuery(stmt) ~= 0 then
+ error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
+ end
+ end,
+
+ select = function(t, fields)
+ local stmt = "SELECT "
+ local first = true
+ local k, v
+
+ for k, v in pairs(fields) do
+ if t._.ddl[v] then
+ error("Trying to select a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
+ end
+ if not first then
+ stmt = stmt .. ", "
+ else
+ first = false
+ end
+ stmt = stmt .. "`" .. v .. "`"
+ end
+
+ stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions
+
+ if t._.db:SafeQuery(stmt) ~= 0 then
+ error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
+ end
+
+ return _luadb.nextrow, t
+ end,
+
+ nextrow = function(t)
+ return t._.db._.conn:FetchRow()
+ end,
+
+ _restricts = {
+ -- All the rendering is done on-the-fly.
+ -- this is the leaf function; thus this is the only one which ough to get the sql_escape.
+ dop = function(r, field, expr, op)
+ if not r._.tbl._.ddl[field] then
+ error("Trying to restrict on a field which doesn't exist: " .. field .. " - inside of table " .. r._.tbl._.tablename)
+ end
+
+ if type(expr) == "string" or type(expr) == "number" then
+ return "`" .. field .. "` " .. op .. ' "' .. t._.db.sql_escape(expr) .. '"'
+ elseif type(expr) == "table" then
+ error("Complex queries not handled for now")
+ else
+ error("Wrong type for expression in operator: " .. type(expr))
+ end
+ end,
+
+ varops = function(r, op, ...)
+ local exprs, first, r, k, v = {...}, true
+
+ for k, v in pairs(exprs) do
+ if not first then
+ r = r .. " " .. op .. " "
+ end
+ r = r .. v
+ first = false
+ end
+ end,
+ },
+
+ restricts = {
+ NOT = function(r, expr)
+ return "NOT " .. expr
+ end,
+
+ AND = function(r, ...) return _luadb._restricts.varops(r, "AND", ...) end,
+ OR = function(r, ...) return _luadb._restricts.varops(r, "OR", ...) end,
+
+ EQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "=") end,
+ NEQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "!=") end,
+ GT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">") end,
+ LT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<") end,
+ GE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">=") end,
+ LE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<=") end,
+ LIKE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "LIKE") end,
+ IN = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "IN") end,
+ },
+
+ restricter = function(t)
+ local r, k, v = { _ = { tbl = t } }
+
+ for k, v in pairs(_luadb.restricts) do
+ r[k] = v
+ end
+
+ return r
+ end,
+
+ restrict = function(t, conditions)
+ -- all in all, "conditions" can just be a direct string. Let the programmer do what he wants, after all.
+ if conditions == nil or conditions == "" then conditions = "1" end
+ t._.conditions = conditions
+ end,
+}
+
+luadb = {
+ SafeQuery = function(db, str)
+ local r = db._.conn:Query(str)
+ if luadb.debugmode then
+ print(str)
+ end
+
+ if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected
+ db._.conn = db:opendb()
+ r = db._.conn:Query(str)
+ end
+
+ return r
+ end,
+
+ opendb = function(id, user, password, base, prefix)
+ local db
+
+-- if user == nil and password == nil and base == nil then
+ -- get into a registry to fetch "id"
+-- else
+-- end
+ if type(id) == "table" and type(id._) == "table" then
+ user = id._.user
+ password = id._.password
+ base = id._.base
+ prefix = id._.prefix
+ id = id._.id
+ end
+ db = SQLConnection(id, user, password, base)
+ return {
+ opentable = _luadb.opentable,
+ SafeQuery = luadb.SafeQuery,
+ sql_escape = sql_escape,
+ ErrNO = function(db) return db._.conn:ErrNO() end,
+ Error = function(db) return db._.conn:Error() end,
+ _ = {
+ conn = db,
+ id = id,
+ user = user,
+ password = password,
+ base = base,
+ prefix = sql_escape(prefix or ""),
+ },
+ }
+ end,
+}