From cd5af67c1d0266c38ddce37f72e906726a776b9c Mon Sep 17 00:00:00 2001 From: Pixel Date: Tue, 29 Sep 2009 17:15:51 -0700 Subject: First overhaul of the dblib, in order to add oracle support. The Oracle support isn't yet complete or functionnal, but the bases are there, and the MySQL layer still works fine. --- lib/dblib.lua | 698 +++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 551 insertions(+), 147 deletions(-) diff --git a/lib/dblib.lua b/lib/dblib.lua index 85fc15a..6b4599a 100644 --- a/lib/dblib.lua +++ b/lib/dblib.lua @@ -22,17 +22,17 @@ ]]-- -- dblib - database object layer --- Currently built-in with MySQL - TODO: support some abstraction layer with drivers. +-- Currently built-in with MySQL and Oracle support. MySQL is the default one. --[[ ddl = { field1 = { type = "varchar", length = 32 }, - field2 = { type = "float", length = 11, decimal = 8, default = 42 }, + field2 = { type = "number", length = 11, decimal = 8, default = 42 }, id = { options = "pri,auto" }, } -db = luadb.opendb(host, login, password, database) +db = luadb.mysql.opendb(host, login, password, database) t = db:opentable("foobar", ddl) t:empty() @@ -65,7 +65,6 @@ print(row.field1) -- prints "test1" ]]-- -local _luadb _luadb = { get_options = function(str) local s = split(str, ",") @@ -77,7 +76,7 @@ _luadb = { return r end, - get_canon_type = function(ttype, k) + get_canon_type = function(ttype, k, db) if ttype == nil then return "INT" elseif type(ttype) ~= "string" then @@ -88,12 +87,15 @@ _luadb = { return "INT" elseif ttype:upper() == "BLOB" then return "BLOB" - elseif ttype:upper() == "DATETIME" then - return "DATETIME" elseif ttype:upper() == "TIMESTAMP" then return "TIMESTAMP" + elseif ttype:upper() == "NUMBER" then + if db._.odbc_name == "mysql" then + return "FLOAT" + end + return "NUMBER" elseif ttype:upper() == "FLOAT" then - return "FLOAT" + return "NUMBER" elseif ttype:upper() == "TEXT" then return "TEXT" else @@ -101,16 +103,17 @@ _luadb = { end end, - generate_fields = function(db, ddl) + generate_fields = function(db, tablename, ddl) local k, v - local r, alters, keys = {}, {}, {} + local r, alters, keys, extras = {}, {}, {}, {} for k, v in pairs(ddl) do local ttype r[k] = "" alters[k] = {} - ttype = _luadb.get_canon_type(v.type, k) + ttype = _luadb.get_canon_type(v.type, k, db) + if ttype == "TEXT" and db._.odbc_name == "oracle" then ttype = "CLOB" end r[k] = ttype if v.length == nil then @@ -126,117 +129,154 @@ _luadb = { end if v.options ~= nil and type(v.options) ~= "string" then - error("Wrong data in ddl - " .. k .. ".options isn't a string.") + error("Wrong data in ddl for table " .. tablename .. " - " .. k .. ".options isn't a string.") end local options = v.options and _luadb.get_options(v.options) or {} - if options.NULL then - r[k] = r[k] .. " NULL" - else - r[k] = r[k] .. " NOT NULL" - end - if options.PRI then table.insert(keys, k) keys[k] = true end if options.UNIQ then - table.insert(alters[k], "ADD UNIQUE (`@fieldname@`)") + table.insert(alters[k], "ADD UNIQUE (" .. db._.fq .. "@fieldname@" .. db._.fq .. ")") end if options.AUTO then - r[k] = r[k] .. " AUTO_INCREMENT" + if db._.odbc_name == "mysql" then + r[k] = r[k] .. " AUTO_INCREMENT" + elseif db._.odbc_name == "oracle" then + local seq = "seq_" .. db.sql_escape(tablename .. "_" .. k) + local trg = "trg_" .. db.sql_escape(tablename .. "_" .. k) .. "_auto" + local owner = db._.prefix == "" and "USER()" or ("'" .. db._.prefix .. "'") + local tbl = db.sql_escape(tablename) + local fld = db.sql_escape(k) + table.insert(extras, [[ +DECLARE + P_EXISTS NUMBER; +BEGIN + SELECT COUNT(*) INTO P_EXISTS FROM ALL_SEQUENCES WHERE SEQUENCE_NAME=']] .. seq .. [[' AND SEQUENCE_OWNER=]] .. owner .. [[; + IF P_EXISTS = 0 THEN + EXECUTE IMMEDIATE 'CREATE SEQUENCE "]] .. seq .. [["'; + END IF; +END; +]]) + table.insert(extras, [[ +CREATE OR REPLACE TRIGGER "]] .. trg .. [[" + BEFORE INSERT ON "]] .. tbl .. [[" + FOR EACH ROW BEGIN IF :new."]] .. fld .. [[" IS NULL THEN + SELECT "]] .. seq .. [[".nextval INTO :new."]] .. fld .. [[" FROM DUAL; + END IF; +END; +]]) + else + error("Unknown odbc for auto increment operations.") + end end if v.default then if type(v.default) ~= "string" and type(v.default) ~= "number" then error("Default value for field " .. k .. " isn't usable.") end - r[k] = r[k] .. ' DEFAULT "' .. db.sql_escape(v.default) .. '"' + r[k] = r[k] .. ' DEFAULT ' .. db._.eq .. db.sql_escape(v.default) .. db._.eq elseif ttype == "TIMESTAMP" then r[k] = r[k] .. ' DEFAULT CURRENT_TIMESTAMP' end + + if options.NULL then + r[k] = r[k] .. " NULL" + else + r[k] = r[k] .. " NOT NULL" + end + end - return r, alters, keys + return r, alters, keys, extras end, - opentable = function(db, tablename, ddl) - local fields, alters, keys = _luadb.generate_fields(db, ddl) - local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`" + opentable = function(db, tablename, ddl, force_create) + local fields, alters, keys, extras = _luadb.generate_fields(db, tablename, ddl) + local tname = db._.fq .. db._.prefix .. db.sql_escape(tablename) .. db._.fq local operations = 0 + local dfields = db._.Desc(db, tablename) - if db:SafeQuery("DESC " .. tname) ~= 0 then + if force_create or dfields == nil then -- table doesn't exist, create it - if db._.conn:ErrNO() == 1146 then - local q = "CREATE TABLE " .. tname .. " (" - local k, v, first + local q = "CREATE TABLE " .. tname .. " (" + local k, v, first + first = true + operations = -1 + + for k, v in pairs(fields) do + if not first then + q = q .. ", " + else + first = false + end + q = q .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v + end + + if #keys ~= 0 then + q = q .. ", PRIMARY KEY (" first = true - operations = -1 - - for k, v in pairs(fields) do + for k, v in ipairs(keys) do if not first then q = q .. ", " else first = false end - q = q .. "`" .. db.sql_escape(k) .. "` " .. v + q = q .. db._.fq .. db.sql_escape(v) .. db._.fq end - - if #keys ~= 0 then - q = q .. ", PRIMARY KEY (" - first = true - for k, v in ipairs(keys) do - if not first then - q = q .. ", " - else - first = false - end - q = q .. "`" .. db.sql_escape(v) .. "`" + q = q .. ")" + end + + q = q .. ")" + + if db._.odbc_name == "mysql" then + q = q .. " ENGINE=InnoDB" + end + + q = q .. ";" + + if db:SafeQuery(q, "update") ~= 0 then + error("Error creating table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) + end + + for k, v in pairs(alters) do + local _, stmt + for _, stmt in pairs(v) do + q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) + db:SafeQuery("DROP TABLE " .. tname .. ";", "update") end - q = q .. ")" - end - - q = q .. ") ENGINE=InnoDB;" - - if db:SafeQuery(q) ~= 0 then - error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) end - - for k, v in pairs(alters) do - local _, stmt - for _, stmt in pairs(v) do - q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) - db:SafeQuery("DROP TABLE " .. tname ";") - end - end + end + + for k, v in pairs(extras) do + if db:SafeQuery(v, "raw") ~= 0 then + error("Error running extra query for " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. v) + db:SafeQuery("DROP TABLE " .. tname .. ";", "update") end - else - error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error()) end else -- table exists, let's check it. - local dfields = {} local i local any_common = false - for i = 1, db._.conn:NumRows() do - dfields[i] = db._.conn:FetchRow() + for i = 1, #dfields do if fields[dfields[i].Field] then any_common = true end end if not any_common then - -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead. + -- strictly no field in common; drop the table and tail-call opentable in order to proceed with a createtable instead. q = "DROP TABLE " .. tname .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error dropping table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end - return db:opentable(tablename, ddl) + return db:opentable(tablename, ddl, true) end local d, k, v = {} @@ -244,9 +284,9 @@ _luadb = { for k, v in pairs(dfields) do d[v.Field] = v if not fields[v.Field] then - q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + q = "ALTER TABLE " .. tname .. " DROP COLUMN " .. db._.fq .. db.sql_escape(v.Field) .. db._.fq .. ";" + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -258,11 +298,11 @@ _luadb = { for k, v in pairs(fields) do q = nil if not dfields[k] then - q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";" + q = "ALTER TABLE " .. tname .. " ADD " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";" else - local identicals, alters = _luadb.compare_fields(ddl, dfields[k]) + local identicals, alters = _luadb.compare_fields(db, ddl, dfields[k]) if not identicals then - q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";" + q = "ALTER TABLE " .. tname .. " MODIFY " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";" local _, a for _, a in pairs(alters) do a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)) @@ -270,8 +310,8 @@ _luadb = { table.insert(deffered_alters, a) else q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -279,8 +319,8 @@ _luadb = { end end if q then - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -290,8 +330,8 @@ _luadb = { for k, v in pairs(deffered_alters) do q = "ALTER TABLE " .. tname .. " " .. v.stmt .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -325,7 +365,7 @@ _luadb = { }, operations end, - compare_fields = function(ddl, field) + compare_fields = function(db, ddl, field) local f = field.Field local identical = true local alters = {} @@ -333,7 +373,7 @@ _luadb = { error("Internal issue with the DDL: key " .. f .. " isn't a table.") end -- check base type first. - local ddltype = _luadb.get_canon_type(ddl[f].type, f) + local ddltype = _luadb.get_canon_type(ddl[f].type, f, db) local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%,(%d+)%)" if desctype == nil then desctype, desclength = field.Type:match "(%w+)%((%d+)%)" @@ -344,7 +384,7 @@ _luadb = { if desctype == nil then error("Error parsing type string from database: " .. f .. ": " .. field.Type) end - desctype = _luadb.get_canon_type(desctype, f) + desctype = _luadb.get_canon_type(desctype, f, db) if ddltype ~= desctype then identical = false @@ -372,24 +412,32 @@ _luadb = { table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -2 } ) end - if options.UNI and field.Key ~= "UNI" then - -- needs to add a unique index - identical = false - table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) - end + -- Needs oracle version of this... + if db._.odbc_name == "mysql" then + if options.UNI and field.Key ~= "UNI" then + -- needs to add a unique index + identical = false + table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) + end - if not options.UNI and field.Key == "UNI" then - -- needs to drop the unique index - identical = false - table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) + if not options.UNI and field.Key == "UNI" then + -- needs to drop the unique index + identical = false + table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) + end end - if not options.AUTO and field.Extra == "auto_increment" then - identical = false + -- Need oracle version of this check. + if db._.odbc_name == "mysql" then + if not options.AUTO and field.Extra == "auto_increment" then + identical = false + end end - if options.AUTO and field.Extra ~= "auto_increment" then - identical = false + if db._.odbc_name == "mysql" then + if options.AUTO and field.Extra ~= "auto_increment" then + identical = false + end end return identical, alters @@ -398,6 +446,7 @@ _luadb = { insert = function(t, data) local k, v, stmt, stmt2 local first = true + local got_blobs = false stmt = "INSERT INTO " .. t._.tname .. " (" stmt2 = ") VALUES (" @@ -406,60 +455,139 @@ _luadb = { if not t._.ddl[k] then error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) end - if not first then - stmt = stmt .. ", " - stmt2 = stmt2 .. ", " + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then + got_blobs = true else - first = false + if not first then + stmt = stmt .. ", " + stmt2 = stmt2 .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. k .. t._.db._.fq + stmt2 = stmt2 .. t._.db._.eq .. v .. t._.db._.eq + end + end + + if t._.db._.odbc_name == "oracle" then + local conn, stmt, status = t._.db._.conn + stmt = conn:createStatement(stmt .. stmt2 .. ") RETURNING ROWIDTOCHAR(rowid) INTO :1") + stmt:registerOutParam(1, OCCISTRING, 30, "VARCHAR2") + status = stmt:executeUpdate() + if status ~= 1 then + error("Error inserting a row inside table " .. t._.tablename .. ": " .. stmt:getErrorMsg()) + end + t._.db._.lastId = stmt:getString(1) + conn:terminateStatement(stmt) + else + if t._.db:SafeQuery(stmt .. stmt2 .. ");", "update") ~= 0 then + error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - stmt = stmt .. "`" .. k .. "`" - stmt2 = stmt2 .. "'" .. v .. "'" end - if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then - error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) + if t._.db._.odbc_name == "mysql" then + t._.db._.lastId = t._.db._.conn:InsertId() end - return t._.db._.conn:InsertId() + return t._.db._.lastId end, delete = function(t) local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";" - if t._.db:SafeQuery(stmt) ~= 0 then + if t._.db:SafeQuery(stmt, "update") ~= 0 then error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - return t._.db._.conn:NumAffectedRows() + if t._.db._.odbc_name == "mysql" then + t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() + elseif t._.db._.odbc_name == "oracle" then + t._.db._.lastAffectedRows = t._.db._.affectedRows + else + error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") + end + + return t._.db._.lastAffectedRows end, update = function(t, data) local stmt = "UPDATE " .. t._.tname .. " SET " local first = true + local got_blobs = false local k, v for k, v in pairs(data) do if not t._.ddl[k] then error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) end - if not first then - stmt = stmt .. ", " - else - first = false - end - stmt = stmt .. "`" .. k .. "`=" - if type(v) == "string" or type(v) == "number" then - stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"' + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then + got_blobs = true else - error("Complex UPDATE queries are not supported yet.") + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq .. "=" + if type(v) == "string" or type(v) == "number" then + stmt = stmt .. t._.db._.eq .. t._.db.sql_escape(v) .. t._.db._.eq + else + error("Complex UPDATE queries are not supported yet.") + end end end stmt = stmt .. " WHERE " .. t._.conditions - if t._.db:SafeQuery(stmt) ~= 0 then + if t._.db:SafeQuery(stmt, "update", got_blobs) ~= 0 then error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - return t._.db._.conn:NumAffectedRows() + if t._.db._.odbc_name == "mysql" then + t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() + elseif t._.db._.odbc_name == "oracle" then + t._.db._.lastAffectedRows = t._.db._.affectedRows + else + error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") + end + + if got_blobs then + -- very ugly way to solve this, but... *shrug* + local conn, buffs, count, stmt, rset = t._.db._.conn, {}, 1 + stmt = "SELECT " + first = false + for k, v in pairs(data) do + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if ftype == "BLOB" or ftype == "TEXT" then + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq + end + buffs[count] = Buffer(true) + buffs[count]:write(v) + count = count + 1 + end + stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions .. " FOR UPDATE" + stmt = conn:createStatement(stmt) + rset = stmt:executeQuery() + while rset:next() ~= 0 do + for k, v in ipairs(buffs) do + rset:setBlob(k, v) + v:seek(0, SEEK_SET) + end + end + for k, v in ipairs(buffs) do + v:destroy() + end + conn:commit() + stmt:closeResultSet(rset) + conn:terminateStatement(stmt) + end + + return t._.db._.lastAffectedRows end, iselect = function(bypass, t, fields, foreign, options) @@ -491,7 +619,7 @@ _luadb = { if bypass then stmt = stmt .. v else - stmt = stmt .. "`" .. v .. "`" + stmt = stmt .. t._.db._.fq .. v .. t._.db._.fq end end end @@ -508,7 +636,7 @@ _luadb = { if not v._.ddl[f.foreign.fieldname] then error("Foreign key points to an unknown field in table: " .. f.foreign.tablename .. "." .. f.foreign.fieldname) end - foreign_conds = foreign_conds .. " AND " .. t._.tname .. ".`" .. fname .. "`=" .. v._.tname .. ".`" .. f.foreign.fieldname .. "`" + foreign_conds = foreign_conds .. " AND " .. t._.tname .. "." .. t._.db._.fq .. fname .. t._.db._.fq "=" .. v._.tname .. "." .. t._.db._.fq .. f.foreign.fieldname .. t._.db._.fq found_table = true end end @@ -542,7 +670,7 @@ _luadb = { error("options.order.dir has to be either 'asc' or 'desc'") end end - stmt = stmt .. " ORDER BY `" .. options.order.by .. "` " .. direction + stmt = stmt .. " ORDER BY " .. t._.db._.fq .. options.order.by .. t._.db._.fq .. " " .. direction end if options.limit then if type(options.limit) ~= "table" then @@ -561,21 +689,27 @@ _luadb = { end size = options.limit.size end - stmt = stmt .. " LIMIT " .. start .. "," .. size + if t._.db._.odbc_name == "mysql" then + stmt = stmt .. " LIMIT " .. start .. "," .. size + elseif t._.db._.odbc_name == "oracle" then + error("LIMIT code for Oracle not implemented yet.") + else + error("No LIMIT code - unknown odbc") + end end end stmt = stmt .. ";" - if t._.db:SafeQuery(stmt) ~= 0 then - error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) + if t._.db:SafeQuery(stmt, "query") ~= 0 then + error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:ErrorStr()) end - return t:numrows(), t:numfields() + return (t._.db._.odbc_name == "oracle" and -1 or t:numrows()), t:numfields() end, empty = function(t, fields) - if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname) ~= 0 then + if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname, "update") ~= 0 then error("Error emptying table " .. t._.tablename .. ": " .. t._.db:Error()) end end, @@ -585,16 +719,29 @@ _luadb = { end, numrows = function(t) - return t._.db._.conn:NumRows() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:NumRows() + elseif t._.db._.odbc_name == "oracle" then + error "No 'numrows' for Oracle yet." + else + error "Unknown odbc for numrows." + end end, numfields = function(t) - return t._.db._.conn:NumFields() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:NumFields() + elseif t._.db._.odbc_name == "oracle" then + return t._.db._.numFields + else + error "Unknown odbc for numfields." + end end, count = function(t) _luadb.iselect(true, t, {"COUNT(*) AS count"}) - return t:nextrow().count + local row = t:nextrow() + return row.count or row.COUNT end, gselect = function(t, ...) @@ -603,7 +750,20 @@ _luadb = { end, nextrow = function(t) - return t._.db._.conn:FetchRow() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:FetchRow() + elseif t._.db._.odbc_name == "oracle" then + local row = nil + if t._.db._.rset:next() ~= 0 then + row = {} + for k, v in t._.db._.fieldsNames do + row[v] = t._.db._.rset:getString(k) + end + end + return row + else + error "Unknown odbc for nextrow." + end end, _restricts = { @@ -615,7 +775,7 @@ _luadb = { end if type(expr) == "string" or type(expr) == "number" then - return r._.tbl._.tname .. ".`" .. field .. "` " .. op .. ' "' .. r._.tbl._.db.sql_escape(expr) .. '"' + return r._.tbl._.tname .. "." .. r._.tbl._.db._.fq .. field .. r._.tbl._.db._.fq .. " " .. op .. " " .. r._.tbl._.db._.eq .. r._.tbl._.db.sql_escape(expr) .. r._.tbl._.db._.eq elseif type(expr) == "table" then error("Complex queries not handled for now") else @@ -671,24 +831,219 @@ _luadb = { t._.conditions = conditions end, - SafeQuery = function(db, str) - local r = db._.conn:Query(str) - if luadb.debugmode then - print(str) - end + mysql = { + SafeQuery = function(db, str, query_type, no_commit) + if luadb.debugmode then + print(str) + end - if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected - db._.conn = luadb:opendb(db)._.conn - r = db._.conn:Query(str) - end + local r = db._.conn:Query(str) + if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected + db._.conn = luadb:opendb(db)._.conn + r = db._.conn:Query(str) + end - return r - end, + return r + end, + + Desc = function(db, tbl) + if db:SafeQuery("DESC " .. db._.fq .. db._.prefix .. db.sql_escape(tbl) .. db._.fq .. ";") ~= 0 then return nil end + + local dfields = {} + for i = 1, db._.conn:NumRows() do + dfields[i] = db._.conn:FetchRow() + end + + return dfields + end, + }, + oracle = { + _SafeQuery = function(db, str, query_type, no_commit) + local r = -1 + + if luadb.debugmode then + print(str) + end + + if not query_type then + error "SafeQuery: Need a query type" + end + + if db._.stmt then + if db._.rset then + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + end + db._.conn:terminateStatement(db._.stmt) + end + + local stmt = db._.conn:createStatement() + + if query_type == "update" or query_type == "raw" then + if query_type == "update" then + str = str:gsub(" *(.*) *;", "%1") + end + db._.affectedRows = stmt:executeUpdate(str) + if db._.affectedRows == nil then + db._.errNO = stmt:getErrorCode() + db._.errStr = stmt:getErrorMsg() + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + else + db._.errNO = 0 + db._.errStr = "No error" + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + r = 0 + end + elseif query_type == "query" then + str = str:gsub(" *(.*) *;", "%1") + local rset = stmt:executeQuery(str) + if rset then + db._.errNO = 0 + db._.errStr = "No error" + db._.numFields, db._.fieldsNames = rset:getFieldsInfo() + db._.rset = rset + r = 0 + else + db._.errNO = stmt:getErrorCode() + db._.errStr = stmt:getErrorMsg() + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + end + elseif type(query_type) == "string" then + error("Unknown query type: " .. query_type) + else + error("Query type mis-specified (not a string)") + end + + db._.stmt = stmt + + if r == 0 and not no_commit then + db._.conn:commit() + end + + return r + end, + + SafeQuery = function(db, str, query_type, no_commit) + local r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) + if r ~= 0 then + -- connection timeout, let's reconnect. + if db._.errNO == 3135 then + if db._.stmt then + if db._.rset then + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + end + db._.conn:terminateStatement(db._.stmt) + end + db._.conn = luadb.oracle.opendb(db)._.conn + r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) + end + end + + return r + end, + + oracle_types = { + [1] = "VARCHAR", + [2] = "NUMBER", + [12] = "DATE", + [96] = "CHAR", + [180] = "TIMESTAMP", + [113] = "BLOB", + [112] = "CLOB", + }, + + Desc = function(db, tbl) + if db:SafeQuery("SELECT DBMS_METADATA.GET_XML('TABLE', '" .. db.sql_escape(tbl) .. "'" .. ((db._.prefix and db._.prefix ~= "") and (", '" .. db.sql_escape(db._.prefix) .. "'") or "") .. ") FROM USER_TABLES;", "query") ~= 0 then + if db:ErrNO() == 31603 then + return nil + end + error("Couldn't properly read the description of table " .. tbl .. " - error: " .. db:ErrorStr()) + end + local n = db._.rset:next() + if n == 0 or not n then + error("Error retrieving table description: no resultset available.") + end + local ddl = flatten_xml(xml.LoadHandle(db._.rset:getClob(1))) + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + local dfields, i = {} + local nfields = {} + local col_list = ddl.ROWSET.ROW.TABLE_T.COL_LIST.COL_LIST_ITEM + if not col_list.__complex then + col_list = { col_list } + end + for i = 1, #col_list do + local col = col_list[i] + local row = { + Field = col.NAME, + Null = (col.NOT_NULL + 0) == 0 and "YES" or "NO", + Default = col.DEFAULT_VAL and col.DEFAULT_VAL:gsub("'(.-)' *", "%1") or "", + Type = _luadb.oracle.oracle_types[col.TYPE_NUM + 0], + Key = "", + Extra = "", + } + if not row.Type then + error("Unknown TYPE_NUM(" .. col.TYPE_NUM .. ") for column " .. row.Field) + end + + if col.PRECISION_NUM and col.SCALE then + row.Type = row.Type .. "(" .. col.PRECISION_NUM .. "," .. col.SCALE .. ")" + elseif col.PRECISION_NUM then + row.Type = row.Type .. "(" .. col.PRECISION_NUM .. ")" + else + if row.Type == "NUMBER" and (col.SCALE + 0) == 0 then + row.Type = "INT" + end + row.Type = row.Type .. "(" .. col.LENGTH .. ")" + end + + table.insert(dfields, row) + nfields[row.Field] = row + end + + -- need to add prefix/owner... + if db:SafeQuery("SELECT USER_CONS_COLUMNS.COLUMN_NAME FROM USER_CONSTRAINTS, USER_CONS_COLUMNS WHERE USER_CONSTRAINTS.TABLE_NAME='" .. db.sql_escape(tbl) .. "' AND CONSTRAINT_TYPE='P' AND USER_CONSTRAINTS.CONSTRAINT_NAME = USER_CONS_COLUMNS.CONSTRAINT_NAME;", "query") ~= 0 then + error("Unable to retrieve " .. table .. "'s primary keys - error: " .. db:ErrorStr()) + end + while db._.rset:next() ~= 0 do + nfields[db._.rset:getString(1)].Key = "PRI" + end + + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + + -- need to add prefix/owner... + if db:SafeQuery("SELECT SEQUENCE_NAME FROM USER_SEQUENCES;", "query") ~= 0 then + error("Unable to retrieve sequences - error: " .. db:ErrorStr()) + end + + local pattern = "^seq_" .. db.sql_escape(tbl) .. "_(%w+)" + + while db._.rset:next() ~= 0 do + local seq_name = db._.rset:getString(1) + if seq_name:match(pattern) then + nfields[(db._.rset:getString(1):gsub(pattern, "%1"))].Extra = "auto_increment" + end + end + + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + + return dfields + end + }, } luadb = { - opendb = function(id, user, password, base, prefix) + mysql = { opendb = function(id, user, password, base, prefix) local db -- if user == nil and password == nil and base == nil then @@ -705,18 +1060,67 @@ luadb = { db = SQLConnection(id, user, password, base) return { opentable = _luadb.opentable, - SafeQuery = _luadb.SafeQuery, + SafeQuery = _luadb.mysql.SafeQuery, sql_escape = sql_escape, ErrNO = function(db) return db._.conn:ErrNO() end, Error = function(db) return db._.conn:Error() end, + ErrorStr = function(db) return db._.conn:ErrNO() .. " - " .. db._.conn:Error() end, _ = { + fq = '`', + eq = '"', conn = db, id = id, user = user, password = password, base = base, prefix = sql_escape(prefix or ""), + odbc_name = "mysql", + Desc = _luadb.mysql.Desc, }, } - end, + end, }, + + oracle = { opendb = function(id, user, password, base, prefix) + local l_env, conn + + if _luadb.oracle.env then + l_env = _luadb.oracle.env + else + l_env = createEnvironment() + _luadb.oracle.env = l_env + l_env:setExceptions(false) + end + + if type(id) == "table" and type(id._) == "table" then + user = id._.user + password = id._.password + base = id._.base + prefix = id._.prefix + id = id._.id + end + conn = l_env:createConnection(user, password, base) + return { + opentable = _luadb.opentable, + SafeQuery = _luadb.oracle.SafeQuery, + sql_escape = sql_escape, + ErrNO = function(db) return db._.errNO end, + Error = function(db) return db._.errStr end, + ErrorStr = function(db) return db._.errStr end, + _ = { + fq = '"', + eq = "'", + conn = conn, + id = id, + user = user, + password = password, + base = base, + prefix = sql_escape(prefix or ""), + odbc_name = "oracle", + Desc = _luadb.oracle.Desc, + }, + } + + end, }, } + +luadb.opendb = luadb.mysql.opendb -- cgit v1.2.3