diff options
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/dblib.lua | 698 | 
1 files changed, 551 insertions, 147 deletions
diff --git a/lib/dblib.lua b/lib/dblib.lua index 85fc15a..6b4599a 100644 --- a/lib/dblib.lua +++ b/lib/dblib.lua @@ -22,17 +22,17 @@  ]]--  -- dblib - database object layer --- Currently built-in with MySQL - TODO: support some abstraction layer with drivers. +-- Currently built-in with MySQL and Oracle support. MySQL is the default one.  --[[  ddl = {      field1 = { type = "varchar", length = 32 }, -    field2 = { type = "float", length = 11, decimal = 8, default = 42 }, +    field2 = { type = "number", length = 11, decimal = 8, default = 42 },      id = { options = "pri,auto" },  } -db = luadb.opendb(host, login, password, database) +db = luadb.mysql.opendb(host, login, password, database)  t = db:opentable("foobar", ddl)  t:empty() @@ -65,7 +65,6 @@ print(row.field1)                     -- prints "test1"  ]]-- -local _luadb  _luadb = {      get_options = function(str)          local s = split(str, ",") @@ -77,7 +76,7 @@ _luadb = {          return r      end, -    get_canon_type = function(ttype, k) +    get_canon_type = function(ttype, k, db)          if ttype == nil then              return "INT"          elseif type(ttype) ~= "string" then @@ -88,12 +87,15 @@ _luadb = {              return "INT"          elseif ttype:upper() == "BLOB" then              return "BLOB" -        elseif ttype:upper() == "DATETIME" then -            return "DATETIME"          elseif ttype:upper() == "TIMESTAMP" then              return "TIMESTAMP" +        elseif ttype:upper() == "NUMBER" then +            if db._.odbc_name == "mysql" then +                return "FLOAT" +            end +            return "NUMBER"          elseif ttype:upper() == "FLOAT" then -            return "FLOAT" +            return "NUMBER"          elseif ttype:upper() == "TEXT" then              return "TEXT"          else @@ -101,16 +103,17 @@ _luadb = {          end                  end, -    generate_fields = function(db, ddl) +    generate_fields = function(db, tablename, ddl)          local k, v -        local r, alters, keys = {}, {}, {} +        local r, alters, keys, extras = {}, {}, {}, {}          for k, v in pairs(ddl) do              local ttype              r[k] = ""              alters[k] = {} -            ttype = _luadb.get_canon_type(v.type, k) +            ttype = _luadb.get_canon_type(v.type, k, db) +            if ttype == "TEXT" and db._.odbc_name == "oracle" then ttype = "CLOB" end              r[k] = ttype              if v.length == nil then @@ -126,117 +129,154 @@ _luadb = {              end              if v.options ~= nil and type(v.options) ~= "string" then -                error("Wrong data in ddl - " .. k .. ".options isn't a string.") +                error("Wrong data in ddl for table " .. tablename .. " - " .. k .. ".options isn't a string.")              end              local options = v.options and _luadb.get_options(v.options) or {} -            if options.NULL then -                r[k] = r[k] .. " NULL" -            else -                r[k] = r[k] .. " NOT NULL" -            end -                          if options.PRI then                  table.insert(keys, k)                  keys[k] = true              end              if options.UNIQ then -                table.insert(alters[k], "ADD UNIQUE (`@fieldname@`)") +                table.insert(alters[k], "ADD UNIQUE (" .. db._.fq .. "@fieldname@" .. db._.fq .. ")")              end              if options.AUTO then -                r[k] = r[k] .. " AUTO_INCREMENT" +                if db._.odbc_name == "mysql" then +                    r[k] = r[k] .. " AUTO_INCREMENT" +                elseif db._.odbc_name == "oracle" then +                    local seq = "seq_" .. db.sql_escape(tablename .. "_" .. k) +                    local trg = "trg_" .. db.sql_escape(tablename .. "_" .. k) .. "_auto" +                    local owner = db._.prefix == "" and "USER()" or ("'" .. db._.prefix .. "'") +                    local tbl = db.sql_escape(tablename) +                    local fld = db.sql_escape(k) +                    table.insert(extras, [[ +DECLARE +    P_EXISTS NUMBER; +BEGIN +    SELECT COUNT(*) INTO P_EXISTS FROM ALL_SEQUENCES WHERE SEQUENCE_NAME=']] .. seq .. [[' AND SEQUENCE_OWNER=]] .. owner .. [[; +    IF P_EXISTS = 0 THEN +        EXECUTE IMMEDIATE 'CREATE SEQUENCE "]] .. seq .. [["'; +    END IF; +END; +]]) +                    table.insert(extras, [[ +CREATE OR REPLACE TRIGGER "]] .. trg .. [[" +    BEFORE INSERT ON "]] .. tbl .. [[" +    FOR EACH ROW BEGIN IF :new."]] .. fld .. [[" IS NULL THEN +        SELECT "]] .. seq .. [[".nextval INTO :new."]] .. fld .. [[" FROM DUAL; +    END IF; +END; +]]) +                else +                    error("Unknown odbc for auto increment operations.") +                end              end              if v.default then                  if type(v.default) ~= "string" and type(v.default) ~= "number" then                      error("Default value for field " .. k .. " isn't usable.")                  end -                r[k] = r[k] .. ' DEFAULT "' .. db.sql_escape(v.default) .. '"' +                r[k] = r[k] .. ' DEFAULT ' .. db._.eq .. db.sql_escape(v.default) .. db._.eq              elseif ttype == "TIMESTAMP" then                  r[k] = r[k] .. ' DEFAULT CURRENT_TIMESTAMP'              end + +            if options.NULL then +                r[k] = r[k] .. " NULL" +            else +                r[k] = r[k] .. " NOT NULL" +            end +                      end -        return r, alters, keys +        return r, alters, keys, extras      end, -    opentable = function(db, tablename, ddl) -        local fields, alters, keys = _luadb.generate_fields(db, ddl) -        local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`" +    opentable = function(db, tablename, ddl, force_create) +        local fields, alters, keys, extras = _luadb.generate_fields(db, tablename, ddl) +        local tname = db._.fq .. db._.prefix .. db.sql_escape(tablename) .. db._.fq          local operations = 0 +        local dfields = db._.Desc(db, tablename) -        if db:SafeQuery("DESC " .. tname) ~= 0 then +        if force_create or dfields == nil then              -- table doesn't exist, create it -            if db._.conn:ErrNO() == 1146 then -                local q = "CREATE TABLE " .. tname .. " (" -                local k, v, first +            local q = "CREATE TABLE " .. tname .. " (" +            local k, v, first +            first = true +            operations = -1 +             +            for k, v in pairs(fields) do +                if not first then +                    q = q .. ", " +                else +                    first = false +                end +                q = q .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v +            end +             +            if #keys ~= 0 then +                q = q .. ", PRIMARY KEY ("                  first = true -                operations = -1 -                 -                for k, v in pairs(fields) do +                for k, v in ipairs(keys) do                      if not first then                          q = q .. ", "                      else                          first = false                      end -                    q = q .. "`" .. db.sql_escape(k) .. "` " .. v +                    q = q .. db._.fq .. db.sql_escape(v) .. db._.fq                  end -                 -                if #keys ~= 0 then -                    q = q .. ", PRIMARY KEY (" -                    first = true -                    for k, v in ipairs(keys) do -                        if not first then -                            q = q .. ", " -                        else -                            first = false -                        end -                        q = q .. "`" .. db.sql_escape(v) .. "`" +                q = q .. ")" +            end +             +            q = q .. ")" +             +            if db._.odbc_name == "mysql" then +                q = q .. " ENGINE=InnoDB" +            end +             +            q = q .. ";" +             +            if db:SafeQuery(q, "update") ~= 0 then +                error("Error creating table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) +            end +             +            for k, v in pairs(alters) do +                local _, stmt +                for _, stmt in pairs(v) do +                    q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) +                    if db:SafeQuery(q, "update") ~= 0 then +                        error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q) +                        db:SafeQuery("DROP TABLE " .. tname .. ";", "update")                      end -                    q = q .. ")" -                end -                 -                q = q .. ") ENGINE=InnoDB;" -                 -                if db:SafeQuery(q) ~= 0 then -                    error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q)                  end -                 -                for k, v in pairs(alters) do -                    local _, stmt -                    for _, stmt in pairs(v) do -                        q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) -                        if db:SafeQuery(q) ~= 0 then -                            error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) -                            db:SafeQuery("DROP TABLE " .. tname ";") -                        end -                    end +            end + +            for k, v in pairs(extras) do +                if db:SafeQuery(v, "raw") ~= 0 then +                    error("Error running extra query for " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. v) +                    db:SafeQuery("DROP TABLE " .. tname .. ";", "update")                  end -            else -                error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error())              end          else              -- table exists, let's check it. -            local dfields = {}              local i              local any_common = false -            for i = 1, db._.conn:NumRows() do -                dfields[i] = db._.conn:FetchRow() +            for i = 1, #dfields do                  if fields[dfields[i].Field] then                      any_common = true                  end              end              if not any_common then -                -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead. +                -- strictly no field in common; drop the table and tail-call opentable in order to proceed with a createtable instead.                  q = "DROP TABLE " .. tname .. ";" -                if db:SafeQuery(q) ~= 0 then -                    error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                if db:SafeQuery(q, "update") ~= 0 then +                    error("Error dropping table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)                  end -                return db:opentable(tablename, ddl) +                return db:opentable(tablename, ddl, true)              end              local d, k, v = {} @@ -244,9 +284,9 @@ _luadb = {              for k, v in pairs(dfields) do                  d[v.Field] = v                  if not fields[v.Field] then -                    q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;" -                    if db:SafeQuery(q) ~= 0 then -                        error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                    q = "ALTER TABLE " .. tname .. " DROP COLUMN " .. db._.fq .. db.sql_escape(v.Field) .. db._.fq .. ";" +                    if db:SafeQuery(q, "update") ~= 0 then +                        error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)                      end                      operations = operations + 1                  end @@ -258,11 +298,11 @@ _luadb = {              for k, v in pairs(fields) do                  q = nil                  if not dfields[k] then -                    q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";" +                    q = "ALTER TABLE " .. tname .. " ADD " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";"                  else -                    local identicals, alters = _luadb.compare_fields(ddl, dfields[k]) +                    local identicals, alters = _luadb.compare_fields(db, ddl, dfields[k])                      if not identicals then -                        q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";" +                        q = "ALTER TABLE " .. tname .. " MODIFY " .. db._.fq .. db.sql_escape(k) .. db._.fq .. " " .. v .. ";"                          local _, a                          for _, a in pairs(alters) do                              a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)) @@ -270,8 +310,8 @@ _luadb = {                                  table.insert(deffered_alters, a)                              else                                  q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" -                                if db:SafeQuery(q) ~= 0 then -                                    error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                                if db:SafeQuery(q, "update") ~= 0 then +                                    error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)                                  end                                  operations = operations + 1                              end @@ -279,8 +319,8 @@ _luadb = {                      end                  end                  if q then -                    if db:SafeQuery(q) ~= 0 then -                        error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                    if db:SafeQuery(q, "update") ~= 0 then +                        error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)                      end                      operations = operations + 1                  end @@ -290,8 +330,8 @@ _luadb = {              for k, v in pairs(deffered_alters) do                  q = "ALTER TABLE " .. tname .. " " .. v.stmt .. ";" -                if db:SafeQuery(q) ~= 0 then -                    error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                if db:SafeQuery(q, "update") ~= 0 then +                    error("Error altering table " .. tname .. ": " .. db:ErrorStr() .. " - query run = " .. q)                  end                  operations = operations + 1              end @@ -325,7 +365,7 @@ _luadb = {          }, operations      end, -    compare_fields = function(ddl, field) +    compare_fields = function(db, ddl, field)          local f = field.Field          local identical = true          local alters = {} @@ -333,7 +373,7 @@ _luadb = {              error("Internal issue with the DDL: key " .. f .. " isn't a table.")          end          -- check base type first. -        local ddltype = _luadb.get_canon_type(ddl[f].type, f) +        local ddltype = _luadb.get_canon_type(ddl[f].type, f, db)          local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%,(%d+)%)"          if desctype == nil then              desctype, desclength = field.Type:match "(%w+)%((%d+)%)" @@ -344,7 +384,7 @@ _luadb = {          if desctype == nil then              error("Error parsing type string from database: " .. f .. ": " .. field.Type)          end -        desctype = _luadb.get_canon_type(desctype, f) +        desctype = _luadb.get_canon_type(desctype, f, db)          if ddltype ~= desctype then              identical = false @@ -372,24 +412,32 @@ _luadb = {              table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -2 } )          end -        if options.UNI and field.Key ~= "UNI" then -            -- needs to add a unique index -            identical = false -            table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) -        end +        -- Needs oracle version of this... +        if db._.odbc_name == "mysql" then +            if options.UNI and field.Key ~= "UNI" then +                -- needs to add a unique index +                identical = false +                table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) +            end -        if not options.UNI and field.Key == "UNI" then -            -- needs to drop the unique index -            identical = false -            table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) +            if not options.UNI and field.Key == "UNI" then +                -- needs to drop the unique index +                identical = false +                table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) +            end          end -        if not options.AUTO and field.Extra == "auto_increment" then -            identical = false +        -- Need oracle version of this check. +        if db._.odbc_name == "mysql" then +            if not options.AUTO and field.Extra == "auto_increment" then +                identical = false +            end          end -        if options.AUTO and field.Extra ~= "auto_increment" then -            identical = false +        if db._.odbc_name == "mysql" then +            if options.AUTO and field.Extra ~= "auto_increment" then +                identical = false +            end          end          return identical, alters @@ -398,6 +446,7 @@ _luadb = {      insert = function(t, data)          local k, v, stmt, stmt2          local first = true +        local got_blobs = false          stmt = "INSERT INTO " .. t._.tname .. " ("          stmt2 = ") VALUES (" @@ -406,60 +455,139 @@ _luadb = {              if not t._.ddl[k] then                  error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)              end -            if not first then -                stmt = stmt .. ", " -                stmt2 = stmt2 .. ", " +            local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) +            if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then +                got_blobs = true              else -                first = false +                if not first then +                    stmt = stmt .. ", " +                    stmt2 = stmt2 .. ", " +                else +                    first = false +                end +                stmt = stmt .. t._.db._.fq .. k .. t._.db._.fq +                stmt2 = stmt2 .. t._.db._.eq .. v .. t._.db._.eq +            end +        end + +        if t._.db._.odbc_name == "oracle" then +            local conn, stmt, status = t._.db._.conn +            stmt = conn:createStatement(stmt .. stmt2 .. ") RETURNING ROWIDTOCHAR(rowid) INTO :1") +            stmt:registerOutParam(1, OCCISTRING, 30, "VARCHAR2") +            status = stmt:executeUpdate() +            if status ~= 1 then +                error("Error inserting a row inside table " .. t._.tablename .. ": " .. stmt:getErrorMsg()) +            end +            t._.db._.lastId = stmt:getString(1) +            conn:terminateStatement(stmt) +        else +            if t._.db:SafeQuery(stmt .. stmt2 .. ");", "update") ~= 0 then +                error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error())              end -            stmt = stmt .. "`" .. k .. "`" -            stmt2 = stmt2 .. "'" .. v .. "'"          end -        if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then -            error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        if t._.db._.odbc_name == "mysql" then +            t._.db._.lastId = t._.db._.conn:InsertId()          end -        return t._.db._.conn:InsertId() +        return t._.db._.lastId      end,      delete = function(t)          local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";" -        if t._.db:SafeQuery(stmt) ~= 0 then +        if t._.db:SafeQuery(stmt, "update") ~= 0 then              error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())          end -        return t._.db._.conn:NumAffectedRows() +        if t._.db._.odbc_name == "mysql" then +            t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() +        elseif t._.db._.odbc_name == "oracle" then +            t._.db._.lastAffectedRows = t._.db._.affectedRows +        else +            error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") +        end +         +        return t._.db._.lastAffectedRows      end,      update = function(t, data)          local stmt = "UPDATE " .. t._.tname .. " SET "          local first = true +        local got_blobs = false          local k, v          for k, v in pairs(data) do              if not t._.ddl[k] then                  error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename)              end -            if not first then -                stmt = stmt .. ", " -            else -                first = false -            end -            stmt = stmt .. "`" .. k .. "`=" -            if type(v) == "string" or type(v) == "number" then -                stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"' +            local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) +            if t._.db._.odbc_name == "oracle" and (ftype == "BLOB" or ftype == "TEXT") then +                got_blobs = true              else -                error("Complex UPDATE queries are not supported yet.") +                if not first then +                    stmt = stmt .. ", " +                else +                    first = false +                end +                stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq .. "=" +                if type(v) == "string" or type(v) == "number" then +                    stmt = stmt .. t._.db._.eq .. t._.db.sql_escape(v) .. t._.db._.eq +                else +                    error("Complex UPDATE queries are not supported yet.") +                end              end          end          stmt = stmt .. " WHERE " .. t._.conditions -        if t._.db:SafeQuery(stmt) ~= 0 then +        if t._.db:SafeQuery(stmt, "update", got_blobs) ~= 0 then              error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error())          end -        return t._.db._.conn:NumAffectedRows() +        if t._.db._.odbc_name == "mysql" then +            t._.db._.lastAffectedRows = t._.db._.conn:NumAffectedRows() +        elseif t._.db._.odbc_name == "oracle" then +            t._.db._.lastAffectedRows = t._.db._.affectedRows +        else +            error("Don't know how to handle odbc " .. t._.db._.odbc_name .. " in update.") +        end +         +        if got_blobs then +            -- very ugly way to solve this, but... *shrug* +            local conn, buffs, count, stmt, rset = t._.db._.conn, {}, 1 +            stmt = "SELECT " +            first = false +            for k, v in pairs(data) do +                local ftype = _luadb.get_canon_type(t._.ddl[k].type, k, t._.db) +                if ftype == "BLOB" or ftype == "TEXT" then +                    if not first then +                        stmt = stmt .. ", " +                    else +                        first = false +                    end +                    stmt = stmt .. t._.db._.fq .. t._.db.sql_escape(k) .. t._.db._.fq +                end +                buffs[count] = Buffer(true) +                buffs[count]:write(v) +                count = count + 1 +            end +            stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions .. " FOR UPDATE" +            stmt = conn:createStatement(stmt) +            rset = stmt:executeQuery() +            while rset:next() ~= 0 do +                for k, v in ipairs(buffs) do +                    rset:setBlob(k, v) +                    v:seek(0, SEEK_SET) +                end +            end +            for k, v in ipairs(buffs) do +                v:destroy() +            end +            conn:commit() +            stmt:closeResultSet(rset) +            conn:terminateStatement(stmt) +        end +         +        return t._.db._.lastAffectedRows      end,      iselect = function(bypass, t, fields, foreign, options) @@ -491,7 +619,7 @@ _luadb = {                  if bypass then  		    stmt = stmt .. v                  else -		    stmt = stmt .. "`" .. v .. "`" +		    stmt = stmt .. t._.db._.fq .. v .. t._.db._.fq  		end              end          end @@ -508,7 +636,7 @@ _luadb = {                          if not v._.ddl[f.foreign.fieldname] then                              error("Foreign key points to an unknown field in table: " .. f.foreign.tablename .. "." .. f.foreign.fieldname)                          end -                        foreign_conds = foreign_conds .. " AND " .. t._.tname .. ".`" .. fname .. "`=" .. v._.tname .. ".`" .. f.foreign.fieldname .. "`" +                        foreign_conds = foreign_conds .. " AND " .. t._.tname .. "." .. t._.db._.fq .. fname .. t._.db._.fq "=" .. v._.tname .. "." .. t._.db._.fq .. f.foreign.fieldname .. t._.db._.fq                          found_table = true                      end                  end @@ -542,7 +670,7 @@ _luadb = {                          error("options.order.dir has to be either 'asc' or 'desc'")                      end                  end -                stmt = stmt .. " ORDER BY `" .. options.order.by .. "` " .. direction +                stmt = stmt .. " ORDER BY " .. t._.db._.fq .. options.order.by .. t._.db._.fq .. " " .. direction              end              if options.limit then                  if type(options.limit) ~= "table" then @@ -561,21 +689,27 @@ _luadb = {                      end                      size = options.limit.size                  end -                stmt = stmt .. " LIMIT " .. start .. "," .. size +                if t._.db._.odbc_name == "mysql" then +                    stmt = stmt .. " LIMIT " .. start .. "," .. size +                elseif t._.db._.odbc_name == "oracle" then +                    error("LIMIT code for Oracle not implemented yet.") +                else +                    error("No LIMIT code - unknown odbc") +                end              end          end          stmt = stmt .. ";" -        if t._.db:SafeQuery(stmt) ~= 0 then -            error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        if t._.db:SafeQuery(stmt, "query") ~= 0 then +            error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:ErrorStr())          end -        return t:numrows(), t:numfields() +        return (t._.db._.odbc_name == "oracle" and -1 or t:numrows()), t:numfields()      end,      empty = function(t, fields) -        if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname) ~= 0 then +        if t._.db:SafeQuery("TRUNCATE TABLE " .. t._.tname, "update") ~= 0 then              error("Error emptying table " .. t._.tablename .. ": " .. t._.db:Error())          end      end, @@ -585,16 +719,29 @@ _luadb = {      end,      numrows = function(t) -	return t._.db._.conn:NumRows() +        if t._.db._.odbc_name == "mysql" then +	    return t._.db._.conn:NumRows() +	elseif t._.db._.odbc_name == "oracle" then +	    error "No 'numrows' for Oracle yet." +	else +	    error "Unknown odbc for numrows." +	end      end,      numfields = function(t) -	return t._.db._.conn:NumFields() +        if t._.db._.odbc_name == "mysql" then +	    return t._.db._.conn:NumFields() +	elseif t._.db._.odbc_name == "oracle" then +	    return t._.db._.numFields +	else +	    error "Unknown odbc for numfields." +	end      end,      count = function(t)  	_luadb.iselect(true, t, {"COUNT(*) AS count"}) -	return t:nextrow().count +	local row = t:nextrow() +	return row.count or row.COUNT      end,      gselect = function(t, ...) @@ -603,7 +750,20 @@ _luadb = {      end,      nextrow = function(t) -        return t._.db._.conn:FetchRow() +        if t._.db._.odbc_name == "mysql" then +            return t._.db._.conn:FetchRow() +        elseif t._.db._.odbc_name == "oracle" then +            local row = nil +            if t._.db._.rset:next() ~= 0 then +                row = {} +                for k, v in t._.db._.fieldsNames do +                    row[v] = t._.db._.rset:getString(k) +                end +            end +            return row +        else +            error "Unknown odbc for nextrow." +        end      end,      _restricts = { @@ -615,7 +775,7 @@ _luadb = {              end              if type(expr) == "string" or type(expr) == "number" then -                return r._.tbl._.tname .. ".`" .. field .. "` " .. op .. ' "' .. r._.tbl._.db.sql_escape(expr) .. '"' +                return r._.tbl._.tname .. "." .. r._.tbl._.db._.fq .. field .. r._.tbl._.db._.fq .. " " .. op .. " " .. r._.tbl._.db._.eq .. r._.tbl._.db.sql_escape(expr) .. r._.tbl._.db._.eq              elseif type(expr) == "table" then                  error("Complex queries not handled for now")              else @@ -671,24 +831,219 @@ _luadb = {          t._.conditions = conditions      end, -    SafeQuery = function(db, str) -        local r = db._.conn:Query(str) -        if luadb.debugmode then -            print(str) -        end +    mysql = { +        SafeQuery = function(db, str, query_type, no_commit) +            if luadb.debugmode then +                print(str) +            end -        if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected -            db._.conn = luadb:opendb(db)._.conn -            r = db._.conn:Query(str) -        end +            local r = db._.conn:Query(str) +            if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected +                db._.conn = luadb:opendb(db)._.conn +                r = db._.conn:Query(str) +            end -        return r -    end, +            return r +        end, +         +        Desc = function(db, tbl)             +            if db:SafeQuery("DESC " .. db._.fq .. db._.prefix .. db.sql_escape(tbl) .. db._.fq .. ";") ~= 0 then return nil end +             +            local dfields = {} +            for i = 1, db._.conn:NumRows() do +                dfields[i] = db._.conn:FetchRow() +            end +             +            return dfields +        end, +    }, +    oracle = { +        _SafeQuery = function(db, str, query_type, no_commit) +            local r = -1 +         +            if luadb.debugmode then +                print(str) +            end +         +            if not query_type then +                error "SafeQuery: Need a query type" +            end +             +            if db._.stmt then +                if db._.rset then +                    db._.stmt:closeResultSet(db._.rset) +                    db._.rset = nil +                end +                db._.conn:terminateStatement(db._.stmt) +            end +         +            local stmt = db._.conn:createStatement() +             +            if query_type == "update" or query_type == "raw" then +                if query_type == "update" then +                    str = str:gsub(" *(.*) *;", "%1") +                end +                db._.affectedRows = stmt:executeUpdate(str) +                if db._.affectedRows == nil then +                    db._.errNO = stmt:getErrorCode() +                    db._.errStr = stmt:getErrorMsg() +                    db._.numFields = 0 +                    db._.fieldsNames = {} +                    db._.rset = nil +                else +                    db._.errNO = 0 +                    db._.errStr = "No error" +                    db._.numFields = 0 +                    db._.fieldsNames = {} +                    db._.rset = nil +                    r = 0 +                end +            elseif query_type == "query" then +                str = str:gsub(" *(.*) *;", "%1") +                local rset = stmt:executeQuery(str) +                if rset then +                    db._.errNO = 0 +                    db._.errStr = "No error" +                    db._.numFields, db._.fieldsNames = rset:getFieldsInfo() +                    db._.rset = rset +                    r = 0 +                else +                    db._.errNO = stmt:getErrorCode() +                    db._.errStr = stmt:getErrorMsg() +                    db._.numFields = 0 +                    db._.fieldsNames = {} +                    db._.rset = nil +                end +            elseif type(query_type) == "string" then +                error("Unknown query type: " .. query_type) +            else +                error("Query type mis-specified (not a string)") +            end +         +            db._.stmt = stmt +             +            if r == 0 and not no_commit then +                db._.conn:commit() +            end +         +            return r +        end, +         +        SafeQuery = function(db, str, query_type, no_commit) +            local r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) +            if r ~= 0 then +                -- connection timeout, let's reconnect. +                if db._.errNO == 3135 then +                    if db._.stmt then +                        if db._.rset then +                            db._.stmt:closeResultSet(db._.rset) +                            db._.rset = nil +                        end +                        db._.conn:terminateStatement(db._.stmt) +                    end +                    db._.conn = luadb.oracle.opendb(db)._.conn +                    r = _luadb.oracle._SafeQuery(db, str, query_type, no_commit) +                end +            end +             +            return r +        end, +         +        oracle_types = { +            [1] = "VARCHAR", +            [2] = "NUMBER", +            [12] = "DATE", +            [96] = "CHAR", +            [180] = "TIMESTAMP", +            [113] = "BLOB", +            [112] = "CLOB", +        }, +         +        Desc = function(db, tbl) +            if db:SafeQuery("SELECT DBMS_METADATA.GET_XML('TABLE', '" .. db.sql_escape(tbl) .. "'" .. ((db._.prefix and db._.prefix ~= "") and (", '" .. db.sql_escape(db._.prefix) .. "'") or "") .. ") FROM USER_TABLES;", "query") ~= 0 then +                if db:ErrNO() == 31603 then +                    return nil +                end +                error("Couldn't properly read the description of table " .. tbl .. " - error: " .. db:ErrorStr()) +            end +            local n = db._.rset:next() +            if n == 0 or not n then +                error("Error retrieving table description: no resultset available.") +            end +            local ddl = flatten_xml(xml.LoadHandle(db._.rset:getClob(1))) +            db._.stmt:closeResultSet(db._.rset) +            db._.rset = nil +            local dfields, i = {} +            local nfields = {} +            local col_list = ddl.ROWSET.ROW.TABLE_T.COL_LIST.COL_LIST_ITEM +            if not col_list.__complex then +                col_list = { col_list } +            end +            for i = 1, #col_list do +                local col = col_list[i] +                local row = { +                    Field = col.NAME, +                    Null = (col.NOT_NULL + 0) == 0 and "YES" or "NO", +                    Default = col.DEFAULT_VAL and col.DEFAULT_VAL:gsub("'(.-)' *", "%1") or "", +                    Type = _luadb.oracle.oracle_types[col.TYPE_NUM + 0], +                    Key = "", +                    Extra = "", +                } +                if not row.Type then +                    error("Unknown TYPE_NUM(" .. col.TYPE_NUM .. ") for column " .. row.Field) +                end +                 +                if col.PRECISION_NUM and col.SCALE then +                    row.Type = row.Type .. "(" .. col.PRECISION_NUM .. "," .. col.SCALE .. ")" +                elseif col.PRECISION_NUM then +                    row.Type = row.Type .. "(" .. col.PRECISION_NUM .. ")" +                else +                    if row.Type == "NUMBER" and (col.SCALE + 0) == 0 then +                        row.Type = "INT" +                    end +                    row.Type = row.Type .. "(" .. col.LENGTH .. ")" +                end +                 +                table.insert(dfields, row) +                nfields[row.Field] = row +            end +             +            -- need to add prefix/owner... +            if db:SafeQuery("SELECT USER_CONS_COLUMNS.COLUMN_NAME FROM USER_CONSTRAINTS, USER_CONS_COLUMNS WHERE USER_CONSTRAINTS.TABLE_NAME='" .. db.sql_escape(tbl) .. "' AND CONSTRAINT_TYPE='P' AND USER_CONSTRAINTS.CONSTRAINT_NAME = USER_CONS_COLUMNS.CONSTRAINT_NAME;", "query") ~= 0 then +                error("Unable to retrieve " .. table .. "'s primary keys - error: " .. db:ErrorStr()) +            end +            while db._.rset:next() ~= 0 do +                nfields[db._.rset:getString(1)].Key = "PRI" +            end +             +            db._.stmt:closeResultSet(db._.rset) +            db._.rset = nil + +            -- need to add prefix/owner... +            if db:SafeQuery("SELECT SEQUENCE_NAME FROM USER_SEQUENCES;", "query") ~= 0 then +                error("Unable to retrieve sequences - error: " .. db:ErrorStr()) +            end +             +            local pattern = "^seq_" .. db.sql_escape(tbl) .. "_(%w+)" +             +            while db._.rset:next() ~= 0 do +                local seq_name = db._.rset:getString(1) +                if seq_name:match(pattern) then +                    nfields[(db._.rset:getString(1):gsub(pattern, "%1"))].Extra = "auto_increment" +                end +            end +             +            db._.stmt:closeResultSet(db._.rset) +            db._.rset = nil + +            return dfields +        end +    },  }  luadb = { -    opendb = function(id, user, password, base, prefix) +    mysql = { opendb = function(id, user, password, base, prefix)          local db  --        if user == nil and password == nil and base == nil then @@ -705,18 +1060,67 @@ luadb = {          db = SQLConnection(id, user, password, base)          return {              opentable = _luadb.opentable, -            SafeQuery = _luadb.SafeQuery, +            SafeQuery = _luadb.mysql.SafeQuery,              sql_escape = sql_escape,              ErrNO = function(db) return db._.conn:ErrNO() end,              Error = function(db) return db._.conn:Error() end, +            ErrorStr = function(db) return db._.conn:ErrNO() .. " - " .. db._.conn:Error() end,              _ = { +                fq = '`', +                eq = '"',                  conn = db,                  id = id,                  user = user,                  password = password,                  base = base,                  prefix = sql_escape(prefix or ""), +                odbc_name = "mysql", +                Desc = _luadb.mysql.Desc,              },          } -    end, +    end, }, +     +    oracle = { opendb = function(id, user, password, base, prefix) +        local l_env, conn +         +        if _luadb.oracle.env then +            l_env = _luadb.oracle.env +        else +            l_env = createEnvironment() +            _luadb.oracle.env = l_env +            l_env:setExceptions(false) +        end +         +        if type(id) == "table" and type(id._) == "table" then +            user = id._.user +            password = id._.password +            base = id._.base +            prefix = id._.prefix +            id = id._.id +        end +        conn = l_env:createConnection(user, password, base) +        return { +            opentable = _luadb.opentable, +            SafeQuery = _luadb.oracle.SafeQuery, +            sql_escape = sql_escape, +            ErrNO = function(db) return db._.errNO end, +            Error = function(db) return db._.errStr end, +            ErrorStr = function(db) return db._.errStr end, +            _ = { +                fq = '"', +                eq = "'", +                conn = conn, +                id = id, +                user = user, +                password = password, +                base = base, +                prefix = sql_escape(prefix or ""), +                odbc_name = "oracle", +                Desc = _luadb.oracle.Desc, +            }, +        } +         +    end, },  } + +luadb.opendb = luadb.mysql.opendb  | 
