diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/dblib.lua | 698 |
1 files changed, 551 insertions, 147 deletions
diff --git a/lib/dblib.lua b/lib/dblib.lua index 85fc15a..6b4599a 100644 --- a/lib/dblib.lua +++ b/lib/dblib.lua @@ -22,17 +22,17 @@ ]]-- -- dblib - database object layer --- Currently built-in with MySQL - TODO: support some abstraction layer with drivers. +-- Currently built-in with MySQL and Oracle support. MySQL is the default one. --[[ ddl = { field1 = { type = "varchar", length = 32 }, - field2 = { type = "float", length = 11, decimal = 8, default = 42 }, + field2 = { type = "number", length = 11, decimal = 8, default = 42 }, id = { options = "pri,auto" }, } -db = luadb.opendb(host, login, password, database) +db = luadb.mysql.opendb(host, login, password, database) t = db:opentable("foobar", ddl) t:empty() @@ -65,7 +65,6 @@ print(row.field1) -- prints "test1" ]]-- -local _luadb _luadb = { get_options = function(str) local s = split(str, ",") @@ -77,7 +76,7 @@ _luadb = { return r end, - get_canon_type = function(ttype, k) + get_canon_type = function(ttype, k, db) if ttype == nil then return "INT" elseif type(ttype) ~= "string" then @@ -88,12 +87,15 @@ _luadb = { return "INT" elseif ttype:upper() == "BLOB" then return "BLOB" - elseif ttype:upper() == "DATETIME" then - return "DATETIME" elseif ttype:upper() == "TIMESTAMP" then return "TIMESTAMP" + elseif ttype:upper() == "NUMBER" then + if db._.odbc_name == "mysql" then + return "FLOAT" + end + return "NUMBER" elseif ttype:upper() == "FLOAT" then - return "FLOAT" + return "NUMBER" elseif ttype:upper() == "TEXT" then return "TEXT" else @@ -101,16 +103,17 @@ _luadb = { end end, - generate_fields = function(db, ddl) + generate_fields = function(db, tablename, ddl) local k, v - local r, alters, keys = {}, {}, {} + local r, alters, keys, extras = {}, {}, {}, {} for k, v in pairs(ddl) do local ttype r[k] = "" alters[k] = {} - ttype = _luadb.get_canon_type(v.type, k) + ttype = _luadb.get_canon_type(v.type, k, db) + if ttype == "TEXT" and db._.odbc_name == "oracle" then ttype = "CLOB" end r[k] = ttype if v.length == nil then @@ -126,117 +129,154 @@ _luadb = { end if v.options ~= nil and type(v.options) ~= "string" then - error("Wrong data in ddl - " .. k .. ".options isn't a string.") + error("Wrong data in ddl for table " .. tablename .. " - " .. k .. ".options isn't a string.") end local options = v.options and _luadb.get_options(v.options) or {} - if options.NULL then - r[k] = r[k] .. " NULL" - else - r[k] = r[k] .. " NOT NULL" - end - if options.PRI then table.insert(keys, k) keys[k] = true end if options.UNIQ then - table.insert(alters[k], "ADD UNIQUE (`@fieldname@`)") + table.insert(alters[k], "ADD UNIQUE (" .. db._.fq .. "@fieldname@" .. db._.fq .. ")") end if options.AUTO then - r[k] = r[k] .. " AUTO_INCREMENT" + if db._.odbc_name == "mysql" then + r[k] = r[k] .. " AUTO_INCREMENT" + elseif db._.odbc_name == "oracle" then + local seq = "seq_" .. db.sql_escape(tablename .. "_" .. k) + local trg = "trg_" .. db.sql_escape(tablename .. "_" .. k) .. "_auto" + local owner = db._.prefix == "" and "USER()" or ("'" .. db._.prefix .. "'") + local tbl = db.sql_escape(tablename) + local fld = db.sql_escape(k) + table.insert(extras, [[ +DECLARE + P_EXISTS NUMBER; +BEGIN + SELECT COUNT(*) INTO P_EXISTS FROM ALL_SEQUENCES WHERE SEQUENCE_NAME=']] .. seq .. [[' AND SEQUENCE_OWNER=]] .. owner .. [[; + IF P_EXISTS = 0 THEN + EXECUTE IMMEDIATE 'CREATE SEQUENCE "]] .. seq .. [["'; + END IF; +END; +]]) + table.insert(extras, [[ +CREATE OR REPLACE TRIGGER "]] .. trg .. [[" + BEFORE INSERT ON "]] .. tbl .. [[" + FOR EACH ROW BEGIN IF :new."]] .. fld .. [[" IS NULL THEN + SELECT "]] .. seq .. [[".nextval INTO :new."]] .. fld .. [[" FROM DUAL; + END IF; +END; +]]) + else + error("Unknown odbc for auto increment operations.") + end end if v.default then if type(v.default) ~= "string" and type(v.default) ~= "number" then error("Default value for field " .. k .. " isn't usable.") end - r[k] = r[k] .. ' DEFAULT "' .. db.sql_escape(v.default) .. '"' + r[k] = r[k] .. ' DEFAULT ' .. db._.eq .. db.sql_escape(v.default) .. db._.eq elseif ttype == "TIMESTAMP" then r[k] = r[k] .. ' DEFAULT CURRENT_TIMESTAMP' end + + if options.NULL then + r[k] = r[k] .. " NULL" + else + r[k] = r[k] .. " NOT NULL" + end + end - return r, alters, keys + return r, alters, keys, extras end, - opentable = function(db, tablename, ddl) - local fields, alters, keys = _luadb.generate_fields(db, ddl) - local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`" + opentable = function(db, tablename, ddl, force_create) + local fields, alters, keys, extras = _luadb.generate_fields(db, tablename, ddl) + local tname = db._.fq .. db._.prefix .. db.sql_escape(tablename) .. db._.fq local operations = 0 + local dfields = db._.Desc(db, tablename) - if db:SafeQuery("DESC " .. tname) ~= 0 then + if force_create or dfields == nil then -- table doesn't exist, create it - if db._.conn:ErrNO() == 1146 then - local q = "CREATE TABLE " .. tname .. " (" - local k, v, first + local q = "CREATE TABLE " .. tname .. " (" + local k, v, first + first = true + operations = -1 + + for k, v in pairs(fields) do + if not first then + q = q .. ", " + else + first = false + end + q = q .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v + end + + if #keys ~= 0 then + q = q .. ", PRIMARY KEY (" first = true - operations = -1 - - for k, v in pairs(fields) do + for k, v in ipairs(keys) do if not first then q = q .. ", " else first = false end - q = q .. "`" .. db.sql_escape(k) .. "` " .. v + q = q .. db._.fq .. db.sql_escape(v) .. db._.fq end - - if #keys ~= 0 then - q = q .. ", PRIMARY KEY (" - first = true - for k, v in ipairs(keys) do - if not first then - q = q .. ", " - else - first = false - end - q = q .. "`" .. db.sql_escape(v) .. "`" + q = q .. ")" + end + + q = q .. ")" + + if db._.odbc_name == "mysql" then + q = q .. " ENGINE=InnoDB" + end + + q = q .. ";" + + if db:SafeQuery(q, "update") ~= 0 then + error("Error creating table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) + end + + for k, v in pairs(alters) do + local _, stmt + for _, stmt in pairs(v) do + q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) + db:SafeQuery("DROP TABLE " .. tname .. ";", "update") end - q = q .. ")" - end - - q = q .. ") ENGINE=InnoDB;" - - if db:SafeQuery(q) ~= 0 then - error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) end - - for k, v in pairs(alters) do - local _, stmt - for _, stmt in pairs(v) do - q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) - db:SafeQuery("DROP TABLE " .. tname ";") - end - end + end + + for k, v in pairs(extras) do + if db:SafeQuery(v, "raw") ~= 0 then + error("Error running extra query for " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. v) + db:SafeQuery("DROP TABLE " .. tname .. ";", "update") end - else - error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error()) end else -- table exists, let's check it. - local dfields = {} local i local any_common = false - for i = 1, db._.conn:NumRows() do - dfields[i] = db._.conn:FetchRow() + for i = 1, #dfields do if fields[dfields[i].Field] then any_common = true end end if not any_common then - -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead. + -- strictly no field in common; drop the table and tail-call opentable in order to proceed with a createtable instead. q = "DROP TABLE " .. tname .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error dropping table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end - return db:opentable(tablename, ddl) + return db:opentable(tablename, ddl, true) end local d, k, v = {} @@ -244,9 +284,9 @@ _luadb = { for k, v in pairs(dfields) do d[v.Field] = v if not fields[v.Field] then - q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + q = "ALTER TABLE " .. tname .. " DROP COLUMN " .. db._.fq .. db.sql_escape(v.Field) .. db._.fq .. ";" + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -258,11 +298,11 @@ _luadb = { for k, v in pairs(fields) do q = nil if not dfields[k] then - q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";" + q = "ALTER TABLE " .. tname .. " ADD " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";" else - local identicals, alters = _luadb.compare_fields(ddl, dfields[k]) + local identicals, alters = _luadb.compare_fields(db, ddl, dfields[k]) if not identicals then - q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";" + q = "ALTER TABLE " .. tname .. " MODIFY " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";" local _, a for _, a in pairs(alters) do a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)) @@ -270,8 +310,8 @@ _luadb = { table.insert(deffered_alters, a) else q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -279,8 +319,8 @@ _luadb = { end end if q then - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -290,8 +330,8 @@ _luadb = { for k, v in pairs(deffered_alters) do q = "ALTER TABLE " .. tname .. " " .. v.stmt .. ";" - if db:SafeQuery(q) ~= 0 then - error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) + if db:SafeQuery(q, "update") ~= 0 then + error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) end operations = operations + 1 end @@ -325,7 +365,7 @@ _luadb = { }, operations end, - compare_fields = function(ddl, field) + compare_fields = function(db, ddl, field) local f = field.Field local identical = true local alters = {} @@ -333,7 +373,7 @@ _luadb = { error("Internal issue with the DDL: key " .. f .. " isn't a table.") end -- check base type first. - local ddltype = _luadb.get_canon_type(ddl[f].type, f) + local ddltype = _luadb.get_canon_type(ddl[f].type, f, db) local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%,(%d+)%)" if desctype == nil then desctype, desclength = field.Type:match "(%w+)%((%d+)%)" @@ -344,7 +384,7 @@ _luadb = { if desctype == nil then error("Error parsing type string from database: " .. f .. ": " .. field.Type) end - desctype = _luadb.get_canon_type(desctype, f) + desctype = _luadb.get_canon_type(desctype, f, db) if ddltype ~= desctype then identical = false @@ -372,24 +412,32 @@ _luadb = { table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -2 } ) end - if options.UNI and field.Key ~= "UNI" then - -- needs to add a unique index - identical = false - table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) - end + -- Needs oracle version of this... + if db._.odbc_name == "mysql" then + if options.UNI and field.Key ~= "UNI" then + -- needs to add a unique index + identical = false + table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) + end - if not options.UNI and field.Key == "UNI" then - -- needs to drop the unique index - identical = false - table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) + if not options.UNI and field.Key == "UNI" then + -- needs to drop the unique index + identical = false + table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) + end end - if not options.AUTO and field.Extra == "auto_increment" then - identical = false + -- Need oracle version of this check. + if db._.odbc_name == "mysql" then + if not options.AUTO and field.Extra == "auto_increment" then + identical = false + end end - if options.AUTO and field.Extra ~= "auto_increment" then - identical = false + if db._.odbc_name == "mysql" then + if options.AUTO and field.Extra ~= "auto_increment" then + identical = false + end end return identical, alters @@ -398,6 +446,7 @@ _luadb = { insert = function(t, data) local k, v, stmt, stmt2 local first = true + local got_blobs = false stmt = "INSERT INTO " .. t._.tname .. " (" stmt2 = ") VALUES (" @@ -406,60 +455,139 @@ _luadb = { if not t._.ddl[k] then error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) end - if not first then - stmt = stmt .. ", " - stmt2 = stmt2 .. ", " + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then + got_blobs = true else - first = false + if not first then + stmt = stmt .. ", " + stmt2 = stmt2 .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. k .. t._.db._.fq + stmt2 = stmt2 .. t._.db._.eq .. v .. t._.db._.eq + end + end + + if t._.db._.odbc_name == "oracle" then + local conn, stmt, status = t._.db._.conn + stmt = conn:createStatement(stmt .. stmt2 .. ") RETURNING ROWIDTOCHAR(rowid) INTO :1") + stmt:registerOutParam(1, OCCISTRING, 30, "VARCHAR2") + status = stmt:executeUpdate() + if status ~= 1 then + error("Error inserting a row inside table " .. t._.tablename .. ": " .. stmt:getErrorMsg()) + end + t._.db._.lastId = stmt:getString(1) + conn:terminateStatement(stmt) + else + if t._.db:SafeQuery(stmt .. stmt2 .. ");", "update") ~= 0 then + error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - stmt = stmt .. "`" .. k .. "`" - stmt2 = stmt2 .. "'" .. v .. "'" end - if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then - error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) + if t._.db._.odbc_name == "mysql" then + t._.db._.lastId = t._.db._.conn:InsertId() end - return t._.db._.conn:InsertId() + return t._.db._.lastId end, delete = function(t) local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";" - if t._.db:SafeQuery(stmt) ~= 0 then + if t._.db:SafeQuery(stmt, "update") ~= 0 then error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - return t._.db._.conn:NumAffectedRows() + if t._.db._.odbc_name == "mysql" then + t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() + elseif t._.db._.odbc_name == "oracle" then + t._.db._.lastAffectedRows = t._.db._.affectedRows + else + error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") + end + + return t._.db._.lastAffectedRows end, update = function(t, data) local stmt = "UPDATE " .. t._.tname .. " SET " local first = true + local got_blobs = false local k, v for k, v in pairs(data) do if not t._.ddl[k] then error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) end - if not first then - stmt = stmt .. ", " - else - first = false - end - stmt = stmt .. "`" .. k .. "`=" - if type(v) == "string" or type(v) == "number" then - stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"' + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then + got_blobs = true else - error("Complex UPDATE queries are not supported yet.") + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq .. "=" + if type(v) == "string" or type(v) == "number" then + stmt = stmt .. t._.db._.eq .. t._.db.sql_escape(v) .. t._.db._.eq + else + error("Complex UPDATE queries are not supported yet.") + end end end stmt = stmt .. " WHERE " .. t._.conditions - if t._.db:SafeQuery(stmt) ~= 0 then + if t._.db:SafeQuery(stmt, "update", got_blobs) ~= 0 then error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) end - return t._.db._.conn:NumAffectedRows() + if t._.db._.odbc_name == "mysql" then + t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() + elseif t._.db._.odbc_name == "oracle" then + t._.db._.lastAffectedRows = t._.db._.affectedRows + else + error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") + end + + if got_blobs then + -- very ugly way to solve this, but... *shrug* + local conn, buffs, count, stmt, rset = t._.db._.conn, {}, 1 + stmt = "SELECT " + first = false + for k, v in pairs(data) do + local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) + if ftype == "BLOB" or ftype == "TEXT" then + if not first then + stmt = stmt .. ", " + else + first = false + end + stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq + end + buffs[count] = Buffer(true) + buffs[count]:write(v) + count = count + 1 + end + stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions .. " FOR UPDATE" + stmt = conn:createStatement(stmt) + rset = stmt:executeQuery() + while rset:next() ~= 0 do + for k, v in ipairs(buffs) do + rset:setBlob(k, v) + v:seek(0, SEEK_SET) + end + end + for k, v in ipairs(buffs) do + v:destroy() + end + conn:commit() + stmt:closeResultSet(rset) + conn:terminateStatement(stmt) + end + + return t._.db._.lastAffectedRows end, iselect = function(bypass, t, fields, foreign, options) @@ -491,7 +619,7 @@ _luadb = { if bypass then stmt = stmt .. v else - stmt = stmt .. "`" .. v .. "`" + stmt = stmt .. t._.db._.fq .. v .. t._.db._.fq end end end @@ -508,7 +636,7 @@ _luadb = { if not v._.ddl[f.foreign.fieldname] then error("Foreign key points to an unknown field in table: " .. f.foreign.tablename .. "." .. f.foreign.fieldname) end - foreign_conds = foreign_conds .. " AND " .. t._.tname .. ".`" .. fname .. "`=" .. v._.tname .. ".`" .. f.foreign.fieldname .. "`" + foreign_conds = foreign_conds .. " AND " .. t._.tname .. "." .. t._.db._.fq .. fname .. t._.db._.fq "=" .. v._.tname .. "." .. t._.db._.fq .. f.foreign.fieldname .. t._.db._.fq found_table = true end end @@ -542,7 +670,7 @@ _luadb = { error("options.order.dir has to be either 'asc' or 'desc'") end end - stmt = stmt .. " ORDER BY `" .. options.order.by .. "` " .. direction + stmt = stmt .. " ORDER BY " .. t._.db._.fq .. options.order.by .. t._.db._.fq .. " " .. direction end if options.limit then if type(options.limit) ~= "table" then @@ -561,21 +689,27 @@ _luadb = { end size = options.limit.size end - stmt = stmt .. " LIMIT " .. start .. "," .. size + if t._.db._.odbc_name == "mysql" then + stmt = stmt .. " LIMIT " .. start .. "," .. size + elseif t._.db._.odbc_name == "oracle" then + error("LIMIT code for Oracle not implemented yet.") + else + error("No LIMIT code - unknown odbc") + end end end stmt = stmt .. ";" - if t._.db:SafeQuery(stmt) ~= 0 then - error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) + if t._.db:SafeQuery(stmt, "query") ~= 0 then + error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:ErrorStr()) end - return t:numrows(), t:numfields() + return (t._.db._.odbc_name == "oracle" and -1 or t:numrows()), t:numfields() end, empty = function(t, fields) - if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname) ~= 0 then + if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname, "update") ~= 0 then error("Error emptying table " .. t._.tablename .. ": " .. t._.db:Error()) end end, @@ -585,16 +719,29 @@ _luadb = { end, numrows = function(t) - return t._.db._.conn:NumRows() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:NumRows() + elseif t._.db._.odbc_name == "oracle" then + error "No 'numrows' for Oracle yet." + else + error "Unknown odbc for numrows." + end end, numfields = function(t) - return t._.db._.conn:NumFields() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:NumFields() + elseif t._.db._.odbc_name == "oracle" then + return t._.db._.numFields + else + error "Unknown odbc for numfields." + end end, count = function(t) _luadb.iselect(true, t, {"COUNT(*) AS count"}) - return t:nextrow().count + local row = t:nextrow() + return row.count or row.COUNT end, gselect = function(t, ...) @@ -603,7 +750,20 @@ _luadb = { end, nextrow = function(t) - return t._.db._.conn:FetchRow() + if t._.db._.odbc_name == "mysql" then + return t._.db._.conn:FetchRow() + elseif t._.db._.odbc_name == "oracle" then + local row = nil + if t._.db._.rset:next() ~= 0 then + row = {} + for k, v in t._.db._.fieldsNames do + row[v] = t._.db._.rset:getString(k) + end + end + return row + else + error "Unknown odbc for nextrow." + end end, _restricts = { @@ -615,7 +775,7 @@ _luadb = { end if type(expr) == "string" or type(expr) == "number" then - return r._.tbl._.tname .. ".`" .. field .. "` " .. op .. ' "' .. r._.tbl._.db.sql_escape(expr) .. '"' + return r._.tbl._.tname .. "." .. r._.tbl._.db._.fq .. field .. r._.tbl._.db._.fq .. " " .. op .. " " .. r._.tbl._.db._.eq .. r._.tbl._.db.sql_escape(expr) .. r._.tbl._.db._.eq elseif type(expr) == "table" then error("Complex queries not handled for now") else @@ -671,24 +831,219 @@ _luadb = { t._.conditions = conditions end, - SafeQuery = function(db, str) - local r = db._.conn:Query(str) - if luadb.debugmode then - print(str) - end + mysql = { + SafeQuery = function(db, str, query_type, no_commit) + if luadb.debugmode then + print(str) + end - if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected - db._.conn = luadb:opendb(db)._.conn - r = db._.conn:Query(str) - end + local r = db._.conn:Query(str) + if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected + db._.conn = luadb:opendb(db)._.conn + r = db._.conn:Query(str) + end - return r - end, + return r + end, + + Desc = function(db, tbl) + if db:SafeQuery("DESC " .. db._.fq .. db._.prefix .. db.sql_escape(tbl) .. db._.fq .. ";") ~= 0 then return nil end + + local dfields = {} + for i = 1, db._.conn:NumRows() do + dfields[i] = db._.conn:FetchRow() + end + + return dfields + end, + }, + oracle = { + _SafeQuery = function(db, str, query_type, no_commit) + local r = -1 + + if luadb.debugmode then + print(str) + end + + if not query_type then + error "SafeQuery: Need a query type" + end + + if db._.stmt then + if db._.rset then + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + end + db._.conn:terminateStatement(db._.stmt) + end + + local stmt = db._.conn:createStatement() + + if query_type == "update" or query_type == "raw" then + if query_type == "update" then + str = str:gsub(" *(.*) *;", "%1") + end + db._.affectedRows = stmt:executeUpdate(str) + if db._.affectedRows == nil then + db._.errNO = stmt:getErrorCode() + db._.errStr = stmt:getErrorMsg() + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + else + db._.errNO = 0 + db._.errStr = "No error" + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + r = 0 + end + elseif query_type == "query" then + str = str:gsub(" *(.*) *;", "%1") + local rset = stmt:executeQuery(str) + if rset then + db._.errNO = 0 + db._.errStr = "No error" + db._.numFields, db._.fieldsNames = rset:getFieldsInfo() + db._.rset = rset + r = 0 + else + db._.errNO = stmt:getErrorCode() + db._.errStr = stmt:getErrorMsg() + db._.numFields = 0 + db._.fieldsNames = {} + db._.rset = nil + end + elseif type(query_type) == "string" then + error("Unknown query type: " .. query_type) + else + error("Query type mis-specified (not a string)") + end + + db._.stmt = stmt + + if r == 0 and not no_commit then + db._.conn:commit() + end + + return r + end, + + SafeQuery = function(db, str, query_type, no_commit) + local r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) + if r ~= 0 then + -- connection timeout, let's reconnect. + if db._.errNO == 3135 then + if db._.stmt then + if db._.rset then + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + end + db._.conn:terminateStatement(db._.stmt) + end + db._.conn = luadb.oracle.opendb(db)._.conn + r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) + end + end + + return r + end, + + oracle_types = { + [1] = "VARCHAR", + [2] = "NUMBER", + [12] = "DATE", + [96] = "CHAR", + [180] = "TIMESTAMP", + [113] = "BLOB", + [112] = "CLOB", + }, + + Desc = function(db, tbl) + if db:SafeQuery("SELECT DBMS_METADATA.GET_XML('TABLE', '" .. db.sql_escape(tbl) .. "'" .. ((db._.prefix and db._.prefix ~= "") and (", '" .. db.sql_escape(db._.prefix) .. "'") or "") .. ") FROM USER_TABLES;", "query") ~= 0 then + if db:ErrNO() == 31603 then + return nil + end + error("Couldn't properly read the description of table " .. tbl .. " - error: " .. db:ErrorStr()) + end + local n = db._.rset:next() + if n == 0 or not n then + error("Error retrieving table description: no resultset available.") + end + local ddl = flatten_xml(xml.LoadHandle(db._.rset:getClob(1))) + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + local dfields, i = {} + local nfields = {} + local col_list = ddl.ROWSET.ROW.TABLE_T.COL_LIST.COL_LIST_ITEM + if not col_list.__complex then + col_list = { col_list } + end + for i = 1, #col_list do + local col = col_list[i] + local row = { + Field = col.NAME, + Null = (col.NOT_NULL + 0) == 0 and "YES" or "NO", + Default = col.DEFAULT_VAL and col.DEFAULT_VAL:gsub("'(.-)' *", "%1") or "", + Type = _luadb.oracle.oracle_types[col.TYPE_NUM + 0], + Key = "", + Extra = "", + } + if not row.Type then + error("Unknown TYPE_NUM(" .. col.TYPE_NUM .. ") for column " .. row.Field) + end + + if col.PRECISION_NUM and col.SCALE then + row.Type = row.Type .. "(" .. col.PRECISION_NUM .. "," .. col.SCALE .. ")" + elseif col.PRECISION_NUM then + row.Type = row.Type .. "(" .. col.PRECISION_NUM .. ")" + else + if row.Type == "NUMBER" and (col.SCALE + 0) == 0 then + row.Type = "INT" + end + row.Type = row.Type .. "(" .. col.LENGTH .. ")" + end + + table.insert(dfields, row) + nfields[row.Field] = row + end + + -- need to add prefix/owner... + if db:SafeQuery("SELECT USER_CONS_COLUMNS.COLUMN_NAME FROM USER_CONSTRAINTS, USER_CONS_COLUMNS WHERE USER_CONSTRAINTS.TABLE_NAME='" .. db.sql_escape(tbl) .. "' AND CONSTRAINT_TYPE='P' AND USER_CONSTRAINTS.CONSTRAINT_NAME = USER_CONS_COLUMNS.CONSTRAINT_NAME;", "query") ~= 0 then + error("Unable to retrieve " .. table .. "'s primary keys - error: " .. db:ErrorStr()) + end + while db._.rset:next() ~= 0 do + nfields[db._.rset:getString(1)].Key = "PRI" + end + + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + + -- need to add prefix/owner... + if db:SafeQuery("SELECT SEQUENCE_NAME FROM USER_SEQUENCES;", "query") ~= 0 then + error("Unable to retrieve sequences - error: " .. db:ErrorStr()) + end + + local pattern = "^seq_" .. db.sql_escape(tbl) .. "_(%w+)" + + while db._.rset:next() ~= 0 do + local seq_name = db._.rset:getString(1) + if seq_name:match(pattern) then + nfields[(db._.rset:getString(1):gsub(pattern, "%1"))].Extra = "auto_increment" + end + end + + db._.stmt:closeResultSet(db._.rset) + db._.rset = nil + + return dfields + end + }, } luadb = { - opendb = function(id, user, password, base, prefix) + mysql = { opendb = function(id, user, password, base, prefix) local db -- if user == nil and password == nil and base == nil then @@ -705,18 +1060,67 @@ luadb = { db = SQLConnection(id, user, password, base) return { opentable = _luadb.opentable, - SafeQuery = _luadb.SafeQuery, + SafeQuery = _luadb.mysql.SafeQuery, sql_escape = sql_escape, ErrNO = function(db) return db._.conn:ErrNO() end, Error = function(db) return db._.conn:Error() end, + ErrorStr = function(db) return db._.conn:ErrNO() .. " - " .. db._.conn:Error() end, _ = { + fq = '`', + eq = '"', conn = db, id = id, user = user, password = password, base = base, prefix = sql_escape(prefix or ""), + odbc_name = "mysql", + Desc = _luadb.mysql.Desc, }, } - end, + end, }, + + oracle = { opendb = function(id, user, password, base, prefix) + local l_env, conn + + if _luadb.oracle.env then + l_env = _luadb.oracle.env + else + l_env = createEnvironment() + _luadb.oracle.env = l_env + l_env:setExceptions(false) + end + + if type(id) == "table" and type(id._) == "table" then + user = id._.user + password = id._.password + base = id._.base + prefix = id._.prefix + id = id._.id + end + conn = l_env:createConnection(user, password, base) + return { + opentable = _luadb.opentable, + SafeQuery = _luadb.oracle.SafeQuery, + sql_escape = sql_escape, + ErrNO = function(db) return db._.errNO end, + Error = function(db) return db._.errStr end, + ErrorStr = function(db) return db._.errStr end, + _ = { + fq = '"', + eq = "'", + conn = conn, + id = id, + user = user, + password = password, + base = base, + prefix = sql_escape(prefix or ""), + odbc_name = "oracle", + Desc = _luadb.oracle.Desc, + }, + } + + end, }, } + +luadb.opendb = luadb.mysql.opendb |