--[[
/*
 *  Baltisot
 *  Copyright (C) 1999-2008 Nicolas "Pixel" Noble
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */
]]--

-- dblib - database object layer
-- Currently built-in with MySQL - TODO: support some abstraction layer with drivers.
--
-- This file relies on three globals provided elsewhere: split(str, sep),
-- sql_escape(value) and SQLConnection(id, user, password, base).

--[[

ddl = {
    field1 = { type = "varchar", length = 32 },
    field2 = { type = "float", length = 11, decimal = 8, default = 42 },
    id = { options = "pri,auto" },
}

db = luadb.opendb(host, login, password, database)
t = db:opentable("foobar", ddl)

t:empty()
x = t:insert{ field1 = "test1" } -- returns 1
x = t:insert{ field1 = "test2" } -- returns 2
x = t:insert{ field1 = "test3" } -- returns 3
x = t:insert{ field1 = "test4", field2 = 18 } -- returns 4
x = t:insert{ field1 = "test5", field2 = 20 } -- returns 5

r = t:restricter()

t:restrict(r:EQ("id", 3)) -- WHERE `id` = "3"
x = t:update{ field1 = "3tset" } -- returns 1

t:restrict(r:EQ("field2", 42)) -- WHERE `field2` = "42"
x = t:update{ field2 = 28 } -- returns 3

t:restrict(r:EQ("id", 4)) -- WHERE `id` = "4"
x = t:delete() -- returns 1

t:restrict(r:AND(r:LIKE("field1", "test%"), r:NEQ("field2", 20)))
-- WHERE (`field1` LIKE "test%") AND (`field2` != "20")
for r in t:gselect{ "field1" } do print(r.field1) end -- prints the strings "test1" and "test2"

t:restrict(r:LIKE("field1", "test%"))
print(t:count()) -- prints "3"

t:restrict() -- WHERE 1
t:select()
print(t:numrows()) -- prints "4"
print(t:numfields()) -- prints "3"
row = t:nextrow()
print(row.field1) -- prints "test1"

]]--

-- Column types accepted in a ddl (and recognised when read back from DESC).
local KNOWN_TYPES = {
    VARCHAR = true,
    INT = true,
    BLOB = true,
    DATETIME = true,
    FLOAT = true,
}

-- Runs `query` through db:SafeQuery and raises
-- "<failure>: <errno> - <error> - query run = <query>" when it fails.
local function must_query(db, query, failure)
    if db:SafeQuery(query) ~= 0 then
        error(failure .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. query)
    end
end

-- Substitutes the @fieldname@ placeholder of an ALTER fragment.
-- A function is used as the replacement so that a '%' in the field name is
-- not interpreted by string.gsub.
local function bind_fieldname(stmt, fieldname)
    return (string.gsub(stmt, "@fieldname@", function() return fieldname end))
end

local _luadb

_luadb = {
    -- Parses a comma-separated option string (e.g. "pri,auto") into a set
    -- keyed by the upper-cased option names. Surrounding blanks are ignored.
    -- "UNI", "UNIQ" and "UNIQUE" are treated as synonyms: whichever is given,
    -- both r.UNI and r.UNIQ are set, so that table creation and table
    -- comparison agree on what a unique field is.
    get_options = function(str)
        local r = {}
        for _, v in pairs(split(str, ",")) do
            local name = v:upper():match("^%s*(.-)%s*$")
            if name ~= "" then
                r[name] = true
            end
        end
        if r.UNI or r.UNIQ or r.UNIQUE then
            r.UNI = true
            r.UNIQ = true
        end
        return r
    end,

    -- Returns the canonical (upper-case) name of a column type.
    -- A nil type means "INT". `k` is the field name, used in error messages.
    -- Raises when ttype is not a string or is not one of the known types.
    get_canon_type = function(ttype, k)
        if ttype == nil then
            return "INT"
        end
        if type(ttype) ~= "string" then
            error("Wrong data in ddl - " .. k .. ".type isn't a string.")
        end
        local canon = ttype:upper()
        if KNOWN_TYPES[canon] then
            return canon
        end
        error("Unknow data type in ddl - " .. k .. ": " .. canon)
    end,

    -- Turns a ddl into SQL fragments.
    -- Returns two tables keyed by field name:
    --   * the column definition string ("VARCHAR(32) NOT NULL", ...)
    --   * a list of extra ALTER TABLE fragments (with an @fieldname@
    --     placeholder) to run once the table has been created.
    -- Raises on malformed ddl entries.
    generate_fields = function(db, ddl)
        local r, r2 = {}, {}

        for k, v in pairs(ddl) do
            local def = _luadb.get_canon_type(v.type, k)
            r2[k] = {}

            if v.length == nil then
                -- no size specification
            elseif type(v.length) ~= "number" then
                error("Wrong data in ddl - " .. k .. ".length isn't a number.")
            elseif v.decimal == nil then
                def = def .. "(" .. v.length .. ")"
            elseif type(v.decimal) ~= "number" then
                error("Wrong data in ddl - " .. k .. ".decimal isn't a number.")
            else
                def = def .. "(" .. v.length .. "," .. v.decimal .. ")"
            end

            if v.options ~= nil and type(v.options) ~= "string" then
                error("Wrong data in ddl - " .. k .. ".options isn't a string.")
            end
            local options = v.options and _luadb.get_options(v.options) or {}

            if options.NULL then
                def = def .. " NULL"
            else
                def = def .. " NOT NULL"
            end
            if options.PRI then
                def = def .. " PRIMARY KEY"
            end
            if options.UNIQ then
                table.insert(r2[k], "ADD UNIQUE (`@fieldname@`)")
            end
            if options.AUTO then
                def = def .. " AUTO_INCREMENT"
            end

            if v.default then
                if type(v.default) ~= "string" and type(v.default) ~= "number" then
                    error("Default value for field " .. k .. " isn't usable.")
                end
                def = def .. ' DEFAULT "' .. db.sql_escape(v.default) .. '"'
            end

            r[k] = def
        end

        return r, r2
    end,

    -- Opens table `tablename` (the db prefix is prepended), creating it or
    -- altering it so that it matches `ddl`.
    -- Returns the table object and an operation count: -1 when the table had
    -- to be created, otherwise the number of ALTER statements that were run.
    -- Raises when any of the schema queries fails.
    opentable = function(db, tablename, ddl)
        local fields, alters = _luadb.generate_fields(db, ddl)
        local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`"
        local alter_failure = "Error altering table " .. tname
        local operations = 0

        if db:SafeQuery("DESC " .. tname) ~= 0 then
            -- 1146 is the only error handled here: the table doesn't exist.
            if db._.conn:ErrNO() ~= 1146 then
                error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error())
            end

            operations = -1
            local columns = {}
            for k, v in pairs(fields) do
                columns[#columns + 1] = "`" .. db.sql_escape(k) .. "` " .. v
            end
            must_query(db, "CREATE TABLE " .. tname .. " (" .. table.concat(columns, ", ") .. ");",
                "Error creating table " .. tname)

            for k, stmts in pairs(alters) do
                for _, stmt in pairs(stmts) do
                    local q = "ALTER TABLE " .. tname .. " " .. bind_fieldname(stmt, db.sql_escape(k))
                    if db:SafeQuery(q) ~= 0 then
                        -- Build the message first: the DROP below overwrites
                        -- the connection's error state. Then remove the
                        -- half-configured table before raising.
                        local message = alter_failure .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q
                        db:SafeQuery("DROP TABLE " .. tname .. ";")
                        error(message)
                    end
                end
            end
        else
            -- table exists, let's check it.
            -- Fetch the whole description first: the queries issued below
            -- would replace the current result set.
            local described = {}
            local any_common = false
            for i = 1, db._.conn:NumRows() do
                local row = db._.conn:FetchRow()
                described[i] = row
                if fields[row.Field] then
                    any_common = true
                end
            end

            if not any_common then
                -- strictly no field in common; drop the table and recursively
                -- call opentable in order to proceed with a createtable instead.
                must_query(db, "DROP TABLE " .. tname .. ";", "Error dropping table " .. tname)
                return db:opentable(tablename, ddl)
            end

            -- Index the description by field name, dropping the columns the
            -- ddl doesn't know about.
            local dfields = {}
            for _, row in ipairs(described) do
                dfields[row.Field] = row
                if not fields[row.Field] then
                    must_query(db, "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(row.Field) .. "`;",
                        alter_failure)
                    operations = operations + 1
                end
            end

            -- NOTE(review): the column definition used by MODIFY still embeds
            -- " PRIMARY KEY" for "pri" fields, and deferred alters run in
            -- ascending `pri` order (ADD PRIMARY KEY before DROP PRIMARY KEY).
            -- Both look fragile when a primary key moves between fields -
            -- confirm against a real server before relying on it.
            local deferred_alters = {}
            for k, v in pairs(fields) do
                if not dfields[k] then
                    must_query(db, "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";",
                        alter_failure)
                    operations = operations + 1
                else
                    local identical, field_alters = _luadb.compare_fields(ddl, dfields[k])
                    if not identical then
                        for _, a in ipairs(field_alters) do
                            a.stmt = bind_fieldname(a.stmt, db.sql_escape(k))
                            if a.pri < 0 then
                                table.insert(deferred_alters, a)
                            else
                                must_query(db, "ALTER TABLE " .. tname .. " " .. a.stmt .. ";", alter_failure)
                                operations = operations + 1
                            end
                        end
                        -- The MODIFY statement is kept separate from the
                        -- alters above so that it is always the one executed.
                        must_query(db, "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";",
                            alter_failure)
                        operations = operations + 1
                    end
                end
            end

            table.sort(deferred_alters, function(a, b) return a.pri < b.pri end)
            for _, a in ipairs(deferred_alters) do
                must_query(db, "ALTER TABLE " .. tname .. " " .. a.stmt .. ";", alter_failure)
                operations = operations + 1
            end
        end

        return {
            insert = _luadb.insert,
            delete = _luadb.delete,
            select = _luadb.select,
            gselect = _luadb.gselect,
            numrows = _luadb.numrows,
            numfields = _luadb.numfields,
            nextrow = _luadb.nextrow,
            count = _luadb.count,
            update = _luadb.update,
            restrict = _luadb.restrict,
            restricter = _luadb.restricter,
            empty = _luadb.empty,
            _ = {
                tablename = tablename,
                tname = tname,
                db = db,
                ddl = ddl,
                conditions = "1",
            },
        }, operations
    end,

    -- Compares one row of a DESC result (`field`, with keys Field, Type, Key
    -- and Extra) against the matching ddl entry.
    -- Returns `identical` (boolean) and a list of { stmt = ..., pri = ... }
    -- ALTER fragments needed for the keys/indexes; a negative `pri` marks a
    -- fragment the caller has to defer.
    compare_fields = function(ddl, field)
        local f = field.Field
        local identical = true
        local alters = {}
        local spec = ddl[f]

        if type(spec) ~= "table" then
            error("Internal issue with the DDL: key " .. f .. " isn't a table.")
        end

        -- check base type first.
        local ddltype = _luadb.get_canon_type(spec.type, f)
        local desctype, desclength, descdecim = field.Type:match("(%w+)%((%d+)%,(%d+)%)")
        if desctype == nil then
            desctype, desclength = field.Type:match("(%w+)%((%d+)%)")
        end
        if desctype == nil then
            -- types reported without a size, such as "blob" or "datetime"
            desctype = field.Type:match("^(%w+)")
        end
        if desctype == nil then
            error("Error parsing type string from database: " .. f .. ": " .. field.Type)
        end
        desctype = _luadb.get_canon_type(desctype, f)

        if ddltype ~= desctype then
            identical = false
        end
        -- The sizes come from the ddl as numbers and from the pattern match
        -- as strings, hence the tonumber().
        if spec.length ~= nil and spec.length ~= tonumber(desclength) then
            identical = false
        end
        if spec.decimal ~= nil and spec.decimal ~= tonumber(descdecim) then
            identical = false
        end

        local options = spec.options and _luadb.get_options(spec.options) or {}

        if options.PRI and field.Key ~= "PRI" then
            -- needs to add a primary key here, with a specific ALTER TABLE statement
            identical = false
            table.insert(alters, { stmt = "ADD PRIMARY KEY(`@fieldname@`)", pri = -2 })
        end
        if not options.PRI and field.Key == "PRI" then
            -- needs to drop the primary key here.
            identical = false
            table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -1 })
        end
        if options.UNI and field.Key ~= "UNI" then
            -- needs to add a unique index
            identical = false
            table.insert(alters, { stmt = "ADD UNIQUE(`@fieldname@`)", pri = 0 })
        end
        if not options.UNI and field.Key == "UNI" then
            -- needs to drop the unique index
            identical = false
            table.insert(alters, { stmt = "DROP INDEX `@fieldname@`", pri = 0 })
        end

        local is_auto = field.Extra == "auto_increment"
        if (options.AUTO and not is_auto) or (is_auto and not options.AUTO) then
            identical = false
        end

        return identical, alters
    end,

    -- Inserts one row; `data` maps field names to string or number values.
    -- Returns the connection's InsertId(). Raises on an unknown field, on an
    -- unsupported value type, or when the query fails.
    insert = function(t, data)
        local names, values = {}, {}

        for k, v in pairs(data) do
            if not t._.ddl[k] then
                error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
            end
            if type(v) ~= "string" and type(v) ~= "number" then
                error("Complex INSERT queries are not supported yet.")
            end
            names[#names + 1] = "`" .. k .. "`"
            -- escaped and quoted the same way update() does it
            values[#values + 1] = '"' .. t._.db.sql_escape(v) .. '"'
        end

        local stmt = "INSERT INTO " .. t._.tname .. " (" .. table.concat(names, ", ")
            .. ") VALUES (" .. table.concat(values, ", ") .. ");"
        if t._.db:SafeQuery(stmt) ~= 0 then
            error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error())
        end

        return t._.db._.conn:InsertId()
    end,

    -- Deletes the rows matching the current restriction.
    -- Returns the number of affected rows.
    delete = function(t)
        local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";"
        if t._.db:SafeQuery(stmt) ~= 0 then
            error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
        end
        return t._.db._.conn:NumAffectedRows()
    end,

    -- Updates the rows matching the current restriction; `data` maps field
    -- names to string or number values. Returns the number of affected rows.
    update = function(t, data)
        local assignments = {}

        for k, v in pairs(data) do
            if not t._.ddl[k] then
                error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
            end
            if type(v) ~= "string" and type(v) ~= "number" then
                error("Complex UPDATE queries are not supported yet.")
            end
            assignments[#assignments + 1] = "`" .. k .. '`="' .. t._.db.sql_escape(v) .. '"'
        end

        local stmt = "UPDATE " .. t._.tname .. " SET " .. table.concat(assignments, ", ")
            .. " WHERE " .. t._.conditions
        if t._.db:SafeQuery(stmt) ~= 0 then
            error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
        end

        return t._.db._.conn:NumAffectedRows()
    end,

    -- Internal select. `fields` is a list of field names (nil selects "*").
    -- With `bypass` set, the entries are inserted verbatim (no ddl check, no
    -- quoting), which allows expressions such as "COUNT(*) AS count".
    -- Returns the number of rows and the number of fields of the result.
    iselect = function(t, fields, bypass)
        local selection

        if fields == nil then
            selection = "*"
        else
            local parts = {}
            for _, v in pairs(fields) do
                if bypass then
                    parts[#parts + 1] = v
                else
                    if not t._.ddl[v] then
                        error("Trying to select a field which doesn't exist: " .. v .. " - inside of table " .. t._.tablename)
                    end
                    parts[#parts + 1] = "`" .. v .. "`"
                end
            end
            selection = table.concat(parts, ", ")
        end

        local stmt = "SELECT " .. selection .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions
        if t._.db:SafeQuery(stmt) ~= 0 then
            error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())
        end

        return t:numrows(), t:numfields()
    end,

    -- Removes every row of the table (TRUNCATE). `fields` is unused.
    empty = function(t, fields)
        if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname) ~= 0 then
            error("Error emptying table " .. t._.tablename .. ": " .. t._.db:Error())
        end
    end,

    -- Selects `fields` (nil for all) from the rows matching the restriction.
    select = function(t, fields)
        return _luadb.iselect(t, fields)
    end,

    -- Number of rows of the last result, as reported by the connection.
    numrows = function(t)
        return t._.db._.conn:NumRows()
    end,

    -- Number of fields of the last result, as reported by the connection.
    numfields = function(t)
        return t._.db._.conn:NumFields()
    end,

    -- Counts the rows matching the current restriction.
    count = function(t)
        _luadb.iselect(t, { "COUNT(*) AS count" }, true)
        return t:nextrow().count
    end,

    -- Generic-for flavour of select: for row in t:gselect{...} do ... end
    gselect = function(t, fields)
        _luadb.iselect(t, fields)
        return t.nextrow, t
    end,

    -- Fetches the next row of the last result.
    nextrow = function(t)
        return t._.db._.conn:FetchRow()
    end,

    _restricts = {
        -- All the rendering is done on-the-fly.

        -- this is the leaf function; thus this is the only one which ought to get the sql_escape.
        -- Renders `field` <op> "expr" for a string or number expr.
        -- NOTE(review): IN goes through here too and therefore renders
        -- `field` IN "expr" - confirm that this is the intended SQL.
        dop = function(r, field, expr, op)
            local tbl = r._.tbl
            if not tbl._.ddl[field] then
                error("Trying to restrict on a field which doesn't exist: " .. field .. " - inside of table " .. tbl._.tablename)
            end
            if type(expr) == "string" or type(expr) == "number" then
                return "`" .. field .. "` " .. op .. ' "' .. tbl._.db.sql_escape(expr) .. '"'
            elseif type(expr) == "table" then
                error("Complex queries not handled for now")
            else
                error("Wrong type for expression in operator: " .. type(expr))
            end
        end,

        -- Joins already rendered expressions with `op`, in argument order.
        -- With several operands each one is parenthesised, so that nested
        -- AND/OR combinations keep their meaning whatever the SQL operator
        -- precedence is. A single operand is returned as is, none gives "".
        varops = function(r, op, ...)
            local count = select("#", ...)
            if count == 0 then
                return ""
            end
            if count == 1 then
                return (...)
            end

            local exprs = { ... }
            local parts = {}
            for i = 1, count do
                parts[i] = "(" .. exprs[i] .. ")"
            end
            return table.concat(parts, " " .. op .. " ")
        end,
    },

    restricts = {
        -- The operand is parenthesised: NOT binds tighter than AND / OR.
        NOT = function(r, expr) return "NOT (" .. expr .. ")" end,
        AND = function(r, ...) return _luadb._restricts.varops(r, "AND", ...) end,
        OR = function(r, ...) return _luadb._restricts.varops(r, "OR", ...) end,
        EQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "=") end,
        NEQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "!=") end,
        GT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">") end,
        LT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<") end,
        GE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">=") end,
        LE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<=") end,
        LIKE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "LIKE") end,
        IN = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "IN") end,
    },

    -- Builds a restricter bound to table `t`: an object carrying every
    -- operator of _luadb.restricts, used to render WHERE conditions.
    restricter = function(t)
        local r = { _ = { tbl = t } }
        for k, v in pairs(_luadb.restricts) do
            r[k] = v
        end
        return r
    end,

    -- Sets the WHERE clause used by the next select/update/delete/count.
    restrict = function(t, conditions)
        -- all in all, "conditions" can just be a direct string. Let the programmer do what he wants, after all.
        if conditions == nil or conditions == "" then
            conditions = "1"
        end
        t._.conditions = conditions
    end,

    -- Runs a query on the connection and returns its status (0 on success).
    -- On error 2006 (disconnected) a new connection is opened with the stored
    -- credentials and the query is retried once.
    SafeQuery = function(db, str)
        local r = db._.conn:Query(str)
        if luadb.debugmode then
            print(str)
        end
        if r ~= 0 and db._.conn:ErrNO() == 2006 then
            -- disconnected
            db._.conn = SQLConnection(db._.id, db._.user, db._.password, db._.base)
            r = db._.conn:Query(str)
        end
        return r
    end,
}

luadb = {
    -- Opens a database connection and returns the db object.
    -- `id` can also be an existing db object, in which case its stored
    -- credentials and prefix are reused and the other arguments are ignored.
    -- `prefix` (optional) is prepended to every table name.
    opendb = function(id, user, password, base, prefix)
        -- if user == nil and password == nil and base == nil then
        -- get into a registry to fetch "id"
        -- else
        -- end
        if type(id) == "table" and type(id._) == "table" then
            -- NOTE(review): the stored prefix is already escaped and gets
            -- escaped a second time below.
            user = id._.user
            password = id._.password
            base = id._.base
            prefix = id._.prefix
            id = id._.id
        end

        local conn = SQLConnection(id, user, password, base)

        return {
            opentable = _luadb.opentable,
            SafeQuery = _luadb.SafeQuery,
            sql_escape = sql_escape,
            ErrNO = function(db) return db._.conn:ErrNO() end,
            Error = function(db) return db._.conn:Error() end,
            _ = {
                conn = conn,
                id = id,
                user = user,
                password = password,
                base = base,
                prefix = sql_escape(prefix or ""),
            },
        }
    end,
}