--[[
/*
 * Baltisot
 * Copyright (C) 1999-2008 Nicolas "Pixel" Noble
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 */
]]--

-- dblib - database object layer
-- Currently built-in with MySQL and Oracle support. MySQL is the default one.
--
-- NOTE(review): this module relies on globals provided elsewhere by the host
-- program: split, sql_escape, SQLConnection (MySQL), and createEnvironment,
-- OCCISTRING, Buffer, SEEK_SET, flatten_xml, xml (Oracle).

--[[
Usage example:

ddl = {
    field1 = { type = "varchar", length = 32 },
    field2 = { type = "number", length = 11, decimal = 8, default = 42 },
    id = { options = "pri,auto" },
}

db = luadb.mysql.opendb(host, login, password, database)
t = db:opentable("foobar", ddl)

t:empty()

x = t:insert{ field1 = "test1" } -- returns 1
x = t:insert{ field1 = "test2" } -- returns 2
x = t:insert{ field1 = "test3" } -- returns 3
x = t:insert{ field1 = "test4", field2 = 18 } -- returns 4
x = t:insert{ field1 = "test5", field2 = 20 } -- returns 5

r = t:restricter()

t:restrict(r:EQ("id", 3)) -- WHERE `foobar`.`id` = "3"
x = t:update{ field1 = "3tset" } -- returns 1

t:restrict(r:EQ("field2", 42)) -- WHERE `foobar`.`field2` = "42"
x = t:update{ field2 = 28 } -- returns 3

t:restrict(r:EQ("id", 4)) -- WHERE `foobar`.`id` = "4"
x = t:delete() -- returns 1

t:restrict(r:AND(r:LIKE("field1", "test%"), r:NEQ("field2", 20)))
-- WHERE (`foobar`.`field1` LIKE "test%") AND (`foobar`.`field2` != "20")
for r in t:gselect{ "field1" } do
    print(r.field1)
end
-- prints the strings "test1" and "test2"

t:restrict(r:LIKE("field1", "test%"))
print(t:count()) -- prints "3"

t:restrict() -- WHERE 1=1
t:select()
print(t:numrows()) -- prints "4"
print(t:numfields()) -- prints "3"
row = t:nextrow()
print(row.field1) -- prints "test1"
]]--

-- Internal implementation shared by every backend. Backend-specific pieces
-- live in _luadb.mysql and _luadb.oracle; public entry points are in `luadb`.
_luadb = {
    -- Turns a comma-separated option string such as "pri,auto" into a set of
    -- upper-cased flags: { PRI = true, AUTO = true }.
    get_options = function(str)
        local r = {}
        for _, v in pairs(split(str, ",")) do
            r[v:upper()] = true
        end
        return r
    end,

    -- Maps a ddl (or database-reported) type name onto the canonical name used
    -- in generated DDL. `k` is only used in error messages; `db` selects the
    -- backend spelling: MySQL has no NUMBER type, so NUMBER and FLOAT both
    -- become FLOAT there, and NUMBER elsewhere. A nil type means INT.
    get_canon_type = function(ttype, k, db)
        if ttype == nil then
            return "INT"
        elseif type(ttype) ~= "string" then
            error("Wrong data in ddl - " .. k .. ".type isn't a string.")
        end
        local utype = ttype:upper()
        if utype == "VARCHAR" or utype == "INT" or utype == "BLOB"
            or utype == "TIMESTAMP" or utype == "TEXT" then
            return utype
        elseif utype == "NUMBER" or utype == "FLOAT" then
            if db._.odbc_name == "mysql" then
                return "FLOAT"
            end
            return "NUMBER"
        elseif utype == "CLOB" then
            -- TEXT columns are created as CLOB on Oracle (see generate_fields)
            -- and reported back as CLOB by the Oracle Desc.
            return "TEXT"
        end
        error("Unknown data type in ddl - " .. k .. ": " .. utype)
    end,

    -- Translates a ddl table into column definitions. Returns:
    --   r:      field name -> column definition, e.g. "VARCHAR(32) NOT NULL"
    --   alters: field name -> list of ALTER TABLE fragments to run after the
    --           CREATE; "@fieldname@" is substituted by the caller
    --   keys:   ordered list of primary-key field names (also keys[name] = true)
    --   extras: raw statements to run after creation (Oracle AUTO emulation)
    generate_fields = function(db, tablename, ddl)
        local r, alters, keys, extras = {}, {}, {}, {}
        for k, v in pairs(ddl) do
            alters[k] = {}

            local ttype = _luadb.get_canon_type(v.type, k, db)
            if ttype == "TEXT" and db._.odbc_name == "oracle" then
                ttype = "CLOB"
            end
            r[k] = ttype

            -- optional size suffix: "(length)" or "(length,decimal)"
            if v.length == nil then
                -- no size specified
            elseif type(v.length) ~= "number" then
                error("Wrong data in ddl - " .. k .. ".length isn't a number.")
            elseif v.decimal == nil then
                r[k] = r[k] .. "(" .. v.length .. ")"
            elseif type(v.decimal) ~= "number" then
                error("Wrong data in ddl - " .. k .. ".decimal isn't a number.")
            else
                r[k] = r[k] .. "(" .. v.length .. "," .. v.decimal .. ")"
            end

            if v.options ~= nil and type(v.options) ~= "string" then
                error("Wrong data in ddl for table " .. tablename .. " - " .. k .. ".options isn't a string.")
            end
            local options = v.options and _luadb.get_options(v.options) or {}

            if options.PRI then
                table.insert(keys, k)
                keys[k] = true
            end
            -- Both "uniq" and "uni" request a unique index; compare_fields
            -- accepts the same two spellings so the two functions agree.
            if options.UNIQ or options.UNI then
                table.insert(alters[k], "ADD UNIQUE (" .. db._.fq .. "@fieldname@" .. db._.fq .. ")")
            end
            if options.AUTO then
                if db._.odbc_name == "mysql" then
                    r[k] = r[k] .. " AUTO_INCREMENT"
                elseif db._.odbc_name == "oracle" then
                    -- Oracle branch: emulate auto increment with a sequence
                    -- (created only when missing) and a BEFORE INSERT trigger.
                    local seq = "seq_" .. db.sql_escape(tablename .. "_" .. k)
                    local trg = "trg_" .. db.sql_escape(tablename .. "_" .. k) .. "_auto"
                    -- NOTE(review): Oracle's USER is usually written without
                    -- parentheses; confirm "USER()" is accepted here.
                    local owner = db._.prefix == "" and "USER()" or ("'" .. db._.prefix .. "'")
                    local tbl = db.sql_escape(tablename)
                    local fld = db.sql_escape(k)
                    table.insert(extras, [[
DECLARE
    P_EXISTS NUMBER;
BEGIN
    SELECT COUNT(*) INTO P_EXISTS FROM ALL_SEQUENCES WHERE SEQUENCE_NAME=']] .. seq .. [[' AND SEQUENCE_OWNER=]] .. owner .. [[;
    IF P_EXISTS = 0 THEN
        EXECUTE IMMEDIATE 'CREATE SEQUENCE "]] .. seq .. [["';
    END IF;
END;
]])
                    table.insert(extras, [[
CREATE OR REPLACE TRIGGER "]] .. trg .. [[" BEFORE INSERT ON "]] .. tbl .. [[" FOR EACH ROW
BEGIN
    IF :new."]] .. fld .. [[" IS NULL THEN
        SELECT "]] .. seq .. [[".nextval INTO :new."]] .. fld .. [[" FROM DUAL;
    END IF;
END;
]])
                else
                    error("Unknown odbc for auto increment operations.")
                end
            end

            if v.default then
                if type(v.default) ~= "string" and type(v.default) ~= "number" then
                    error("Default value for field " .. k .. " isn't usable.")
                end
                r[k] = r[k] .. ' DEFAULT ' .. db._.eq .. db.sql_escape(v.default) .. db._.eq
            elseif ttype == "TIMESTAMP" then
                r[k] = r[k] .. ' DEFAULT CURRENT_TIMESTAMP'
            end

            if options.NULL then
                r[k] = r[k] .. " NULL"
            else
                r[k] = r[k] .. " NOT NULL"
            end
        end
        return r, alters, keys, extras
    end,

    -- Builds a table object. `tablename` is the raw name, `tname` the quoted
    -- (and possibly prefixed) SQL name. The p* variants wrap the matching
    -- operation in pcall and return its status first.
    newtable = function(tablename, tname, db, ddl)
        return {
            insert = _luadb.insert,
            delete = _luadb.delete,
            select = _luadb.select,
            gselect = _luadb.gselect,
            rselect = _luadb.rselect,
            grselect = _luadb.grselect,
            update = _luadb.update,
            pinsert = function(...) return pcall(_luadb.insert, ...) end,
            pdelete = function(...) return pcall(_luadb.delete, ...) end,
            pselect = function(...) return pcall(_luadb.select, ...) end,
            pgselect = function(...) return pcall(_luadb.gselect, ...) end,
            prselect = function(...) return pcall(_luadb.rselect, ...) end,
            pgrselect = function(...) return pcall(_luadb.grselect, ...) end,
            pupdate = function(...) return pcall(_luadb.update, ...) end,
            numrows = _luadb.numrows,
            numfields = _luadb.numfields,
            nextrow = _luadb.nextrow,
            count = _luadb.count,
            restrict = _luadb.restrict,
            restricter = _luadb.restricter,
            empty = _luadb.empty,
            _ = {
                tablename = tablename,
                tname = tname,
                db = db,
                ddl = ddl,
                conditions = "1=1",
            },
        }
    end,

    -- Opens a table, creating it or reconciling its columns with `ddl` unless
    -- `read_only` is set. Without a ddl, the table is opened as-is.
    -- Returns the table object and an operation count: -1 when the table was
    -- created, otherwise the number of ALTER statements run.
    opentable = function(db, tablename, ddl, force_create, read_only)
        local fq = db._.fq
        local tname
        if db._.prefix and db._.prefix ~= "" and db._.odbc_name == "oracle" then
            -- Oracle: prefix is an owner/schema name -> "OWNER"."TABLE"
            tname = fq .. db._.prefix .. fq .. '.' .. fq .. db.sql_escape(tablename) .. fq
        else
            -- otherwise the prefix is glued onto the table name itself
            tname = fq .. db._.prefix .. db.sql_escape(tablename) .. fq
        end

        if not ddl then
            return _luadb.newtable(tablename, tname, db, ddl)
        end

        local fields, alters, keys, extras = _luadb.generate_fields(db, tablename, ddl)
        local operations = 0
        local dfields = db:Desc(tablename)

        -- Runs a DDL statement, raising "Error <what> table ..." on failure.
        local function run_or_fail(stmt, what)
            if db:SafeQuery(stmt, "update") ~= 0 then
                error("Error " .. what .. " table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. stmt)
            end
        end

        -- Used while creating: drop the half-built table first, then raise.
        -- `msg` is built by the caller before the DROP resets the error state.
        local function drop_and_fail(msg)
            db:SafeQuery("DROP TABLE " .. tname .. ";", "update")
            error(msg)
        end

        if force_create or (dfields == nil and not read_only) then
            -- table doesn't exist (or the caller forces it): create it
            operations = -1
            local columns = {}
            for k, v in pairs(fields) do
                columns[#columns + 1] = fq .. db.sql_escape(k) .. fq .. " " .. v
            end
            local q = "CREATE TABLE " .. tname .. " (" .. table.concat(columns, ", ")
            if #keys ~= 0 then
                local pkeys = {}
                for i, k in ipairs(keys) do
                    pkeys[i] = fq .. db.sql_escape(k) .. fq
                end
                q = q .. ", PRIMARY KEY (" .. table.concat(pkeys, ", ") .. ")"
            end
            q = q .. ")"
            if db._.odbc_name == "mysql" then
                q = q .. " ENGINE=InnoDB"
            end
            q = q .. ";"
            run_or_fail(q, "creating")

            for k, stmts in pairs(alters) do
                for _, stmt in pairs(stmts) do
                    q = "ALTER TABLE " .. tname .. " " .. (string.gsub(stmt, "@fieldname@", db.sql_escape(k)))
                    if db:SafeQuery(q, "update") ~= 0 then
                        drop_and_fail("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)
                    end
                end
            end
            for _, extra in pairs(extras) do
                if db:SafeQuery(extra, "raw") ~= 0 then
                    drop_and_fail("Error running extra query for " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. extra)
                end
            end
        elseif not read_only then
            -- table exists: reconcile its columns with the ddl
            local any_common = false
            for i = 1, #dfields do
                if fields[dfields[i].Field] then
                    any_common = true
                end
            end
            if not any_common then
                -- strictly no field in common; drop the table and tail-call
                -- opentable in order to proceed with a create instead
                run_or_fail("DROP TABLE " .. tname .. ";", "dropping")
                return db:opentable(tablename, ddl, true)
            end

            -- index existing columns by name, dropping those the ddl lost
            local existing = {}
            for _, col in pairs(dfields) do
                existing[col.Field] = col
                if not fields[col.Field] then
                    run_or_fail("ALTER TABLE " .. tname .. " DROP COLUMN " .. fq .. db.sql_escape(col.Field) .. fq .. ";", "altering")
                    operations = operations + 1
                end
            end

            -- add missing columns, modify differing ones; alters with a
            -- negative priority (primary key changes) are deferred to the end
            local deferred_alters = {}
            for k, v in pairs(fields) do
                local q = nil
                if not existing[k] then
                    q = "ALTER TABLE " .. tname .. " ADD " .. fq .. db.sql_escape(k) .. fq .. " " .. v .. ";"
                else
                    local identical, field_alters = _luadb.compare_fields(db, ddl, existing[k])
                    if not identical then
                        q = "ALTER TABLE " .. tname .. " MODIFY " .. fq .. db.sql_escape(k) .. fq .. " " .. v .. ";"
                        for _, a in pairs(field_alters) do
                            a.stmt = (string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)))
                            if a.pri < 0 then
                                table.insert(deferred_alters, a)
                            else
                                -- run now, leaving the pending MODIFY in q intact
                                run_or_fail("ALTER TABLE " .. tname .. " " .. a.stmt .. ";", "altering")
                                operations = operations + 1
                            end
                        end
                    end
                end
                if q then
                    run_or_fail(q, "altering")
                    operations = operations + 1
                end
            end

            -- lowest priority first: DROP PRIMARY KEY (-2) before ADD (-1)
            table.sort(deferred_alters, function(a, b) return a.pri < b.pri end)
            for _, a in ipairs(deferred_alters) do
                run_or_fail("ALTER TABLE " .. tname .. " " .. a.stmt .. ";", "altering")
                operations = operations + 1
            end
        end

        return _luadb.newtable(tablename, tname, db, ddl), operations
    end,

    -- Compares one described column (`field`, a row from db:Desc) against its
    -- ddl entry. Returns whether they match, plus a list of extra ALTER
    -- fragments { stmt = ..., pri = ... } needed to reconcile keys/indexes.
    compare_fields = function(db, ddl, field)
        local f = field.Field
        local identical = true
        local alters = {}
        if type(ddl[f]) ~= "table" then
            error("Internal issue with the DDL: key " .. f .. " isn't a table.")
        end
        local spec = ddl[f]

        -- check base type first: "name(len,dec)", "name(len)" or "name"
        local ddltype = _luadb.get_canon_type(spec.type, f, db)
        local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%,(%d+)%)"
        if desctype == nil then
            desctype, desclength = field.Type:match "(%w+)%((%d+)%)"
        end
        if desctype == nil then
            desctype = field.Type:match "(%w+)"
        end
        if desctype == nil then
            error("Error parsing type string from database: " .. f .. ": " .. field.Type)
        end
        desctype = _luadb.get_canon_type(desctype, f, db)
        if ddltype ~= desctype then
            identical = false
        end
        -- Sizes are only compared when both sides specify one, so a database
        -- that omits a width does not trigger a MODIFY on every open.
        if spec.length and desclength and tonumber(desclength) ~= spec.length then
            identical = false
        end
        if spec.decimal and descdecim and tonumber(descdecim) ~= spec.decimal then
            identical = false
        end

        local options = spec.options and _luadb.get_options(spec.options) or {}
        if options.PRI and field.Key ~= "PRI" then
            -- needs to add a primary key here, with a specific ALTER TABLE statement
            identical = false
            table.insert(alters, { stmt = "ADD PRIMARY KEY(@fieldname@)", pri = -1 })
        end
        if not options.PRI and field.Key == "PRI" then
            -- needs to drop the primary key here
            identical = false
            table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -2 })
        end

        -- NOTE(review): unique-index and auto-increment checks exist for MySQL
        -- only; Oracle versions are still missing.
        if db._.odbc_name == "mysql" then
            local want_unique = options.UNIQ or options.UNI
            if want_unique and field.Key ~= "UNI" then
                identical = false
                table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 })
            end
            if not want_unique and field.Key == "UNI" then
                identical = false
                table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 })
            end
            local is_auto = field.Extra == "auto_increment"
            if (options.AUTO and true or false) ~= is_auto then
                identical = false
            end
        end

        return identical, alters
    end,

    -- Inserts one row; keys and values are escaped with db.sql_escape.
    -- Returns the last insert id (MySQL) or the new row's ROWID (Oracle).
    insert = function(t, data)
        local db = t._.db
        local columns, values = {}, {}
        for k, v in pairs(data) do
            if not t._.ddl[k] then
                error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
            end
            local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, db)
            if db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then
                -- NOTE(review): LOB columns are skipped on Oracle inserts and
                -- never written here; only update() streams LOB values.
            else
                if type(v) ~= "string" and type(v) ~= "number" then
                    error("Complex INSERT queries are not supported yet.")
                end
                columns[#columns + 1] = db._.fq .. db.sql_escape(k) .. db._.fq
                values[#values + 1] = db._.eq .. db.sql_escape(v) .. db._.eq
            end
        end
        local stmt = "INSERT INTO " .. t._.tname .. " (" .. table.concat(columns, ", ")
            .. ") VALUES (" .. table.concat(values, ", ")

        if db._.odbc_name == "oracle" then
            local conn = db._.conn
            local ostmt = conn:createStatement(stmt .. ") RETURNING ROWIDTOCHAR(rowid) INTO :1")
            ostmt:registerOutParam(1, OCCISTRING, 30, "VARCHAR2")
            local status = ostmt:executeUpdate()
            if status ~= 1 then
                local msg = ostmt:getErrorMsg()
                conn:terminateStatement(ostmt)
                error("Error inserting a row inside table " .. t._.tablename .. ": " .. msg)
            end
            db._.lastId = ostmt:getString(1)
            conn:terminateStatement(ostmt)
        else
            if db:SafeQuery(stmt .. ");", "update") ~= 0 then
                error("Error inserting a row inside table " .. t._.tablename .. ": " .. db:Error())
            end
        end

        if db._.odbc_name == "mysql" then
            db._.lastId = db._.conn:InsertId()
        end
        return db._.lastId
    end,

    -- Deletes the rows matching the current restriction; returns the number
    -- of affected rows.
    delete = function(t)
        local db = t._.db
        local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";"
        if db:SafeQuery(stmt, "update") ~= 0 then
            error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. db:Error())
        end
        if db._.odbc_name == "mysql" then
            db._.lastAffectedRows = db._.conn:NumAffectedRows()
        elseif db._.odbc_name == "oracle" then
            db._.lastAffectedRows = db._.affectedRows
        else
            error("Don't know how to handle odbc " .. db._.odbc_name .. " in delete.")
        end
        return db._.lastAffectedRows
    end,

    -- Updates the rows matching the current restriction with `data`
    -- (field -> string/number). Returns the number of affected rows.
    update = function(t, data)
        local db = t._.db
        local assignments = {}
        local got_blobs = false
        for k, v in pairs(data) do
            if not t._.ddl[k] then
                error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)
            end
            local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, db)
            if db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then
                got_blobs = true
            else
                if type(v) ~= "string" and type(v) ~= "number" then
                    error("Complex UPDATE queries are not supported yet.")
                end
                assignments[#assignments + 1] = db._.fq .. db.sql_escape(k) .. db._.fq .. "="
                    .. db._.eq .. db.sql_escape(v) .. db._.eq
            end
        end

        -- NOTE(review): if only LOB fields are given on Oracle, the SET list is
        -- empty and this statement is invalid.
        local stmt = "UPDATE " .. t._.tname .. " SET " .. table.concat(assignments, ", ")
            .. " WHERE " .. t._.conditions
        -- with LOBs pending, skip the commit so the LOB write below joins it
        if db:SafeQuery(stmt, "update", got_blobs) ~= 0 then
            error("Error updating row(s) inside table " .. t._.tablename .. ": " .. db:Error())
        end
        if db._.odbc_name == "mysql" then
            db._.lastAffectedRows = db._.conn:NumAffectedRows()
        elseif db._.odbc_name == "oracle" then
            db._.lastAffectedRows = db._.affectedRows
        else
            error("Don't know how to handle odbc " .. db._.odbc_name .. " in update.")
        end

        if got_blobs then
            -- very ugly way to solve this, but... *shrug*
            -- re-select the LOB columns FOR UPDATE and write each value through
            -- a Buffer; buffer i feeds result column i.
            local conn = db._.conn
            local lob_columns, buffs = {}, {}
            for k, v in pairs(data) do
                local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, db)
                if ftype == "BLOB" or ftype == "TEXT" then
                    lob_columns[#lob_columns + 1] = db._.fq .. db.sql_escape(k) .. db._.fq
                    local buff = Buffer(true)
                    buff:write(v)
                    buffs[#buffs + 1] = buff
                end
            end
            local ostmt = conn:createStatement("SELECT " .. table.concat(lob_columns, ", ")
                .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions .. " FOR UPDATE")
            local rset = ostmt:executeQuery()
            while rset:next() ~= 0 do
                for i, buff in ipairs(buffs) do
                    rset:setBlob(i, buff)
                    buff:seek(0, SEEK_SET)
                end
            end
            for _, buff in ipairs(buffs) do
                buff:destroy()
            end
            conn:commit()
            ostmt:closeResultSet(rset)
            conn:terminateStatement(ostmt)
        end
        return db._.lastAffectedRows
    end,

    -- Shared SELECT builder. `bypass` emits field expressions verbatim and
    -- skips the ddl check (used by rselect/count). `foreign` is a list of other
    -- table objects to join through this table's ddl `foreign` declarations.
    -- `options` may hold order = { by, dir } and limit = { start, size }.
    -- Returns the row count (-1 on Oracle) and the field count.
    iselect = function(bypass, t, fields, foreign, options)
        local db = t._.db
        local fq = db._.fq
        local stmt = "SELECT "
        if fields == nil then
            stmt = stmt .. "*"
        else
            local cols = {}
            for _, v in pairs(fields) do
                local is_foreign = false
                if foreign then
                    for _, f in pairs(foreign) do
                        if f._.ddl[v] then
                            is_foreign = true
                        end
                    end
                end
                if not bypass and not is_foreign and not t._.ddl[v] then
                    error("Trying to select a field which doesn't exist: " .. v .. " - inside of table " .. t._.tablename)
                end
                cols[#cols + 1] = bypass and v or (fq .. v .. fq)
            end
            stmt = stmt .. table.concat(cols, ", ")
        end

        if foreign then
            local foreign_conds = "1=1"
            local tables = { t._.tname }
            for _, v in pairs(foreign) do
                tables[#tables + 1] = v._.tname
                local found_table = false
                for fname, f in pairs(t._.ddl) do
                    if f.foreign and f.foreign.tablename == v._.tablename then
                        if not v._.ddl[f.foreign.fieldname] then
                            error("Foreign key points to an unknown field in table: " .. f.foreign.tablename .. "." .. f.foreign.fieldname)
                        end
                        foreign_conds = foreign_conds .. " AND " .. t._.tname .. "." .. fq .. fname .. fq
                            .. "=" .. v._.tname .. "." .. fq .. f.foreign.fieldname .. fq
                        found_table = true
                    end
                end
                if not found_table then
                    error("Foreign table " .. v._.tablename .. " specified in request didn't match any of the foreign key of table " .. t._.tablename)
                end
            end
            stmt = stmt .. " FROM " .. table.concat(tables, ", ")
                .. " WHERE (" .. t._.conditions .. ") AND (" .. foreign_conds .. ")"
        else
            stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions
        end

        if options then
            if options.order then
                local direction = "ASC"
                if type(options.order) ~= "table" then
                    error("Wrong type for options.order; should be a table.")
                end
                if type(options.order.by) ~= "string" then
                    error("The order option has to contain one valid 'by' field.")
                end
                if options.order.dir then
                    if type(options.order.dir) ~= "string" then
                        error("options.order.dir has to be a string")
                    end
                    if options.order.dir:upper() == "ASC" then
                        direction = "ASC"
                    elseif options.order.dir:upper() == "DESC" then
                        direction = "DESC"
                    else
                        error("options.order.dir has to be either 'asc' or 'desc'")
                    end
                end
                stmt = stmt .. " ORDER BY " .. fq .. options.order.by .. fq .. " " .. direction
            end
            if options.limit then
                if type(options.limit) ~= "table" then
                    error("Wrong type for options.limit; should be a table.")
                end
                local start, size = 0, 30
                if options.limit.start then
                    if type(options.limit.start) ~= "number" then
                        error("options.limit.start has to be a number")
                    end
                    start = options.limit.start
                end
                if options.limit.size then
                    if type(options.limit.size) ~= "number" then
                        error("options.limit.size has to be a number")
                    end
                    size = options.limit.size
                end
                if db._.odbc_name == "mysql" then
                    stmt = stmt .. " LIMIT " .. start .. "," .. size
                elseif db._.odbc_name == "oracle" then
                    error("LIMIT code for Oracle not implemented yet.")
                else
                    error("No LIMIT code - unknown odbc")
                end
            end
        end

        stmt = stmt .. ";"
        if db:SafeQuery(stmt, "query") ~= 0 then
            error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. db:ErrorStr())
        end
        return (db._.odbc_name == "oracle" and -1 or t:numrows()), t:numfields()
    end,

    -- Removes every row with TRUNCATE TABLE. `fields` is unused.
    empty = function(t, fields)
        if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname, "update") ~= 0 then
            error("Error emptying table " .. t._.tablename .. ": " .. t._.db:Error())
        end
    end,

    select = function(...) return _luadb.iselect(false, ...) end,
    rselect = function(...) return _luadb.iselect(true, ...) end,

    -- Row count of the last query (MySQL only).
    numrows = function(t)
        if t._.db._.odbc_name == "mysql" then
            return t._.db._.conn:NumRows()
        elseif t._.db._.odbc_name == "oracle" then
            error "No 'numrows' for Oracle yet."
        else
            error "Unknown odbc for numrows."
        end
    end,

    -- Field count of the last query.
    numfields = function(t)
        if t._.db._.odbc_name == "mysql" then
            return t._.db._.conn:NumFields()
        elseif t._.db._.odbc_name == "oracle" then
            return t._.db._.numFields
        else
            error "Unknown odbc for numfields."
        end
    end,

    -- Number of rows matching the current restriction.
    count = function(t)
        _luadb.iselect(true, t, { "COUNT(*) AS count" })
        local row = t:nextrow()
        -- Oracle reports column names upper-cased
        return row.count or row.COUNT
    end,

    -- Iterator forms: `for row in t:gselect(...) do ... end`.
    gselect = function(t, ...)
        _luadb.iselect(false, t, ...)
        return t.nextrow, t
    end,

    grselect = function(t, ...)
        _luadb.iselect(true, t, ...)
        return t.nextrow, t
    end,

    -- Fetches the next row of the last query as a field -> value table, or
    -- nil when exhausted (Oracle; MySQL returns whatever FetchRow returns).
    nextrow = function(t)
        local db = t._.db
        if db._.odbc_name == "mysql" then
            return db._.conn:FetchRow()
        elseif db._.odbc_name == "oracle" then
            local rset = db._.rset
            if rset:next() == 0 then
                return nil
            end
            local row = {}
            for k, v in pairs(db._.fieldsNames) do
                local spec = t._.ddl and t._.ddl[v]
                local ftype = _luadb.get_canon_type(spec and spec.type or "varchar", v, db)
                if ftype == "BLOB" then
                    row[v] = rset:getBlob(k)
                else
                    row[v] = rset:getString(k)
                end
            end
            return row
        else
            error "Unknown odbc for nextrow."
        end
    end,

    _restricts = {
        -- All the rendering is done on-the-fly.
        -- This is the leaf function; thus it is the only one which ought to
        -- apply sql_escape.
        dop = function(r, field, expr, op)
            local tbl = r._.tbl
            if not tbl._.ddl[field] then
                error("Trying to restrict on a field which doesn't exist: " .. field .. " - inside of table " .. tbl._.tablename)
            end
            local db = tbl._.db
            if type(expr) == "string" or type(expr) == "number" then
                return tbl._.tname .. "." .. db._.fq .. field .. db._.fq .. " " .. op .. " "
                    .. db._.eq .. db.sql_escape(expr) .. db._.eq
            elseif type(expr) == "table" then
                error("Complex queries not handled for now")
            else
                error("Wrong type for expression in operator: " .. type(expr))
            end
        end,

        -- Joins already-rendered sub-expressions with `op`, parenthesizing
        -- each one so nested AND/OR keep their intended grouping.
        varops = function(r, op, ...)
            local parts = {}
            for i = 1, select("#", ...) do
                local expr = select(i, ...)
                if expr ~= nil then
                    parts[#parts + 1] = "(" .. expr .. ")"
                end
            end
            return table.concat(parts, " " .. op .. " ")
        end,
    },

    -- Condition builders exposed on restricter objects; each returns a SQL string.
    restricts = {
        NOT = function(r, expr) return "NOT (" .. expr .. ")" end,
        AND = function(r, ...) return _luadb._restricts.varops(r, "AND", ...) end,
        OR = function(r, ...) return _luadb._restricts.varops(r, "OR", ...) end,
        EQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "=") end,
        NEQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "!=") end,
        GT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">") end,
        LT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<") end,
        GE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">=") end,
        LE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<=") end,
        LIKE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "LIKE") end,
        IN = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "IN") end,
    },

    -- Returns a condition builder bound to table `t`.
    restricter = function(t)
        local r = { _ = { tbl = t } }
        for k, v in pairs(_luadb.restricts) do
            r[k] = v
        end
        return r
    end,

    restrict = function(t, conditions)
        -- all in all, "conditions" can just be a direct string. Let the
        -- programmer do what he wants, after all.
        if conditions == nil or conditions == "" then
            conditions = "1=1"
        end
        t._.conditions = conditions
    end,

    mysql = {
        -- Runs a query; on error 2006 (the original comment: "disconnected")
        -- reconnects with the stored credentials and retries once.
        -- Returns 0 on success.
        SafeQuery = function(db, str, query_type, no_commit)
            if luadb.debugmode then
                print(str)
            end
            local r = db._.conn:Query(str)
            if r ~= 0 and db._.conn:ErrNO() == 2006 then
                db._.conn = luadb.mysql.opendb(db)._.conn
                r = db._.conn:Query(str)
            end
            return r
        end,

        -- Returns the rows of "DESC <table>", or nil if the query fails.
        Desc = function(db, tbl)
            if db:SafeQuery("DESC " .. db._.fq .. db._.prefix .. db.sql_escape(tbl) .. db._.fq .. ";") ~= 0 then
                return nil
            end
            local dfields = {}
            for i = 1, db._.conn:NumRows() do
                dfields[i] = db._.conn:FetchRow()
            end
            return dfields
        end,
    },

    oracle = {
        -- Runs one statement on a fresh OCCI statement object. query_type is
        -- "update", "raw" (no ';' stripping, e.g. PL/SQL blocks) or "query"
        -- (keeps the result set in db._.rset). Commits on success unless
        -- no_commit. Returns 0 on success, -1 on failure.
        _SafeQuery = function(db, str, query_type, no_commit)
            local r = -1
            if luadb.debugmode then
                print(str)
            end
            if not query_type then
                error "SafeQuery: Need a query type"
            end
            -- release the previous statement (and its result set) first
            if db._.stmt then
                if db._.rset then
                    db._.stmt:closeResultSet(db._.rset)
                    db._.rset = nil
                end
                db._.conn:terminateStatement(db._.stmt)
                db._.stmt = nil
            end
            local stmt = db._.conn:createStatement()
            if query_type == "update" or query_type == "raw" then
                if query_type == "update" then
                    -- strip the trailing ';' from plain SQL
                    str = str:gsub("%s*(.*)%s*;", "%1")
                end
                db._.affectedRows = stmt:executeUpdate(str)
                if db._.affectedRows == nil then
                    db._.errNO = stmt:getErrorCode()
                    db._.errStr = stmt:getErrorMsg()
                else
                    db._.errNO = 0
                    db._.errStr = "No error"
                    r = 0
                end
                db._.numFields = 0
                db._.fieldsNames = {}
                db._.rset = nil
            elseif query_type == "query" then
                str = str:gsub("%s*(.*)%s*;", "%1")
                local rset = stmt:executeQuery(str)
                if rset then
                    db._.errNO = 0
                    db._.errStr = "No error"
                    db._.numFields, db._.fieldsNames = rset:getFieldsInfo()
                    db._.rset = rset
                    r = 0
                else
                    db._.errNO = stmt:getErrorCode()
                    db._.errStr = stmt:getErrorMsg()
                    db._.numFields = 0
                    db._.fieldsNames = {}
                    db._.rset = nil
                end
            elseif type(query_type) == "string" then
                error("Unknown query type: " .. query_type)
            else
                error("Query type mis-specified (not a string)")
            end
            db._.stmt = stmt
            if r == 0 and not no_commit then
                db._.conn:commit()
            end
            return r
        end,

        -- _SafeQuery plus one reconnect-and-retry on error 3135 (connection timeout).
        SafeQuery = function(db, str, query_type, no_commit)
            local r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit)
            if r ~= 0 and db._.errNO == 3135 then
                if db._.stmt then
                    if db._.rset then
                        db._.stmt:closeResultSet(db._.rset)
                        db._.rset = nil
                    end
                    db._.conn:terminateStatement(db._.stmt)
                    -- forget it so _SafeQuery doesn't terminate it again
                    -- against the new connection
                    db._.stmt = nil
                end
                local fresh = luadb.oracle.opendb(db)
                if fresh then
                    db._.conn = fresh._.conn
                    r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit)
                end
            end
            return r
        end,

        -- Oracle TYPE_NUM values (from the table metadata XML) -> type names.
        oracle_types = {
            [1] = "VARCHAR",
            [2] = "NUMBER",
            [12] = "DATE",
            [96] = "CHAR",
            [180] = "TIMESTAMP",
            [113] = "BLOB",
            [112] = "CLOB",
        },

        -- Builds MySQL-DESC-like rows (Field, Null, Default, Type, Key, Extra)
        -- from DBMS_METADATA, primary-key constraints and "seq_<table>_<field>"
        -- sequences. Returns nil when the table doesn't exist (error 31603).
        Desc = function(db, tbl)
            if db:SafeQuery("SELECT DBMS_METADATA.GET_XML('TABLE', '" .. db.sql_escape(tbl) .. "'"
                .. ((db._.prefix and db._.prefix ~= "") and (", '" .. db.sql_escape(db._.prefix) .. "'") or "")
                .. ") FROM USER_TABLES;", "query") ~= 0 then
                if db:ErrNO() == 31603 then
                    return nil
                end
                error("Couldn't properly read the description of table " .. tbl .. " - error: " .. db:ErrorStr())
            end
            local n = db._.rset:next()
            if n == 0 or not n then
                error("Error retrieving table description: no resultset available.")
            end
            local ddl = flatten_xml(xml.LoadHandle(db._.rset:getClob(1)))
            db._.stmt:closeResultSet(db._.rset)
            db._.rset = nil

            local dfields = {}
            local nfields = {}
            local col_list = ddl.ROWSET.ROW.TABLE_T.COL_LIST.COL_LIST_ITEM
            if not col_list.__complex then
                col_list = { col_list }
            end
            for i = 1, #col_list do
                local col = col_list[i]
                local row = {
                    Field = col.NAME,
                    Null = (col.NOT_NULL + 0) == 0 and "YES" or "NO",
                    Default = col.DEFAULT_VAL and col.DEFAULT_VAL:gsub("'(.-)' *", "%1") or "",
                    Type = _luadb.oracle.oracle_types[col.TYPE_NUM + 0],
                    Key = "",
                    Extra = "",
                }
                if not row.Type then
                    error("Unknown TYPE_NUM(" .. col.TYPE_NUM .. ") for column " .. row.Field)
                end
                if col.PRECISION_NUM and col.SCALE then
                    row.Type = row.Type .. "(" .. col.PRECISION_NUM .. "," .. col.SCALE .. ")"
                elseif col.PRECISION_NUM then
                    row.Type = row.Type .. "(" .. col.PRECISION_NUM .. ")"
                else
                    if row.Type == "NUMBER" and (not col.SCALE or ((col.SCALE + 0) == 0)) then
                        row.Type = "INT"
                    end
                    row.Type = row.Type .. "(" .. col.LENGTH .. ")"
                end
                table.insert(dfields, row)
                nfields[row.Field] = row
            end

            -- need to add prefix/owner...
            if db:SafeQuery("SELECT USER_CONS_COLUMNS.COLUMN_NAME FROM USER_CONSTRAINTS, USER_CONS_COLUMNS WHERE USER_CONSTRAINTS.TABLE_NAME='"
                .. db.sql_escape(tbl) .. "' AND CONSTRAINT_TYPE='P' AND USER_CONSTRAINTS.CONSTRAINT_NAME = USER_CONS_COLUMNS.CONSTRAINT_NAME;", "query") ~= 0 then
                error("Unable to retrieve " .. tbl .. "'s primary keys - error: " .. db:ErrorStr())
            end
            while db._.rset:next() ~= 0 do
                nfields[db._.rset:getString(1)].Key = "PRI"
            end
            db._.stmt:closeResultSet(db._.rset)
            db._.rset = nil

            -- need to add prefix/owner...
            if db:SafeQuery("SELECT SEQUENCE_NAME FROM USER_SEQUENCES;", "query") ~= 0 then
                error("Unable to retrieve sequences - error: " .. db:ErrorStr())
            end
            -- generate_fields names sequences "seq_<table>_<field>"; compare
            -- as plain text so table names with pattern characters work.
            local seq_prefix = "seq_" .. db.sql_escape(tbl) .. "_"
            while db._.rset:next() ~= 0 do
                local seq_name = db._.rset:getString(1)
                if seq_name:sub(1, #seq_prefix) == seq_prefix then
                    local column = nfields[seq_name:sub(#seq_prefix + 1)]
                    if column then
                        column.Extra = "auto_increment"
                    end
                end
            end
            db._.stmt:closeResultSet(db._.rset)
            db._.rset = nil

            return dfields
        end,
    },
}

-- Public entry points: luadb.<backend>.opendb(...) returns a db object whose
-- :opentable() yields table objects (see the usage example at the top).
luadb = {
    mysql = {
        -- Opens a MySQL connection. `id` may also be a db object previously
        -- returned here, in which case its stored credentials are reused
        -- (this is how SafeQuery reconnects).
        opendb = function(id, user, password, base, prefix)
            -- if user == nil and password == nil and base == nil then
            --     get into a registry to fetch "id"
            -- else
            -- end
            if type(id) == "table" and type(id._) == "table" then
                user = id._.user
                password = id._.password
                base = id._.base
                prefix = id._.prefix
                id = id._.id
            end
            local conn = SQLConnection(id, user, password, base)
            return {
                opentable = _luadb.opentable,
                SafeQuery = _luadb.mysql.SafeQuery,
                sql_escape = sql_escape,
                ErrNO = function(db) return db._.conn:ErrNO() end,
                Error = function(db) return db._.conn:Error() end,
                ErrorStr = function(db) return db._.conn:ErrNO() .. " - " .. db._.conn:Error() end,
                Desc = _luadb.mysql.Desc,
                _ = {
                    fq = '`',
                    eq = '"',
                    conn = conn,
                    id = id,
                    user = user,
                    password = password,
                    base = base,
                    prefix = sql_escape(prefix or ""),
                    odbc_name = "mysql",
                },
            }
        end,
    },

    oracle = {
        -- Opens an Oracle (OCCI) connection, sharing one environment across
        -- calls. Accepts a previous db object as `id`, like the MySQL version.
        -- Returns nil, error code, error message on failure.
        opendb = function(id, user, password, base, prefix)
            local l_env
            if _luadb.oracle.env then
                l_env = _luadb.oracle.env
            else
                l_env = createEnvironment()
                _luadb.oracle.env = l_env
                l_env:setExceptions(false)
            end
            if type(id) == "table" and type(id._) == "table" then
                user = id._.user
                password = id._.password
                base = id._.base
                prefix = id._.prefix
                id = id._.id
            end
            local conn = l_env:createConnection(user, password, base)
            if not conn then
                return nil, l_env:getErrorCode(), l_env:getErrorMsg()
            end
            return {
                opentable = _luadb.opentable,
                SafeQuery = _luadb.oracle.SafeQuery,
                sql_escape = sql_escape,
                ErrNO = function(db) return db._.errNO end,
                Error = function(db) return db._.errStr end,
                ErrorStr = function(db) return db._.errStr end,
                Desc = _luadb.oracle.Desc,
                _ = {
                    fq = '"',
                    eq = "'",
                    conn = conn,
                    id = id,
                    user = user,
                    password = password,
                    base = base,
                    prefix = sql_escape(prefix or ""),
                    odbc_name = "oracle",
                },
            }
        end,
    },
}

luadb.opendb = luadb.mysql.opendb