diff options
| author | Pixel <pixel@nobis-crew.org> | 2008-10-03 10:55:24 -0700 | 
|---|---|---|
| committer | Pixel <pixel@nobis-crew.org> | 2008-10-03 10:55:24 -0700 | 
| commit | 324855967ef23123abbd84fa4913bcb99fb5922a (patch) | |
| tree | f3b70fb18e0beddac9a8768beedf2e6e5e85f8c0 /lib | |
| parent | 314b7e319c1392841b07a0800ed76c48d36c74db (diff) | |
Adding dblib.lua
Diffstat (limited to 'lib')
| -rw-r--r-- | lib/dblib.lua | 519 | 
1 files changed, 519 insertions, 0 deletions
| diff --git a/lib/dblib.lua b/lib/dblib.lua new file mode 100644 index 0000000..6469d9c --- /dev/null +++ b/lib/dblib.lua @@ -0,0 +1,519 @@ +--[[ + +/* + *  Baltisot + *  Copyright (C) 1999-2008 Nicolas "Pixel" Noble + * + *  This program is free software; you can redistribute it and/or modify + *  it under the terms of the GNU General Public License as published by + *  the Free Software Foundation; either version 2 of the License, or + *  (at your option) any later version. + * + *  This program is distributed in the hope that it will be useful, + *  but WITHOUT ANY WARRANTY; without even the implied warranty of + *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the + *  GNU General Public License for more details. + * + *  You should have received a copy of the GNU General Public License + *  along with this program; if not, write to the Free Software + *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA + */ + +]]-- + +-- dblib - database object layer +-- Currently built-in with MySQL - TODO: support some abstraction layer with drivers. + +--[[ + +ddl = { +    field1 = { type = "varchar", length = 32 }, +    field2 = { type = "int", length = 11, decimal = 8, default = 42 }, +    field3 = { options = "pri,auto" }, +} + +]]-- + +--local _luadb +_luadb = { +    get_options = function(str) +        local s = split(str, ",") +        local r = {} +        for k, v in pairs(s) do +            r[v:upper()] = true +        end +         +        return r +    end, +     +    get_canon_type = function(ttype, k) +        if ttype == nil then +            return "INT" +        elseif type(ttype) ~= "string" then +            error("Wrong data in ddl - " .. k .. ".type isn't a string.") +        elseif ttype:upper() == "VARCHAR" then +            return "VARCHAR" +        elseif ttype:upper() == "INT" then +            return "INT" +        elseif ttype:upper() == "BLOB" then +            return "BLOB" +        elseif ttype:upper() == "DATETIME" then +            return "DATETIME" +        elseif ttype:upper() == "NUMBER" then +            return "NUMBER" +        else +            error("Unknow data type in ddl - " .. k .. ": " .. ttype:upper()) +        end             +    end, + +    generate_fields = function(ddl) +        local k, v +        local r, r2 = {}, {} +         +        for k, v in pairs(ddl) do +            r[k] = "" +            r2[k] = {} +             +            r[k] = _luadb.get_canon_type(v.type, k) +             +            if v.length == nil then +                -- do nothing +            elseif type(v.length) ~= "number" then +                error("Wrong data in ddl - " .. k .. ".length isn't a number.") +            elseif v.decimal == nil then +                r[k] = r[k] .. "(" .. v.length .. ")" +            elseif type(v.decimal) ~= "number" then +                error("Wrong data in ddl - " .. k .. ".decimal isn't a number.") +            else +                r[k] = r[k] .. "(" .. v.length .. "," .. v.decimal .. ")" +            end +             +            if v.options ~= nil and type(v.options) ~= "string" then +                error("Wrong data in ddl - " .. k .. ".options isn't a string.") +            end +             +            local options = v.options and _luadb.get_options(v.options) or {} +             +            if options.NULL then +                r[k] = r[k] .. " NULL" +            else +                r[k] = r[k] .. " NOT NULL" +            end +             +            if options.PRI then +                r[k] = r[k] .. " PRIMARY KEY" +            end +             +            if options.UNIQ then +                table.insert(r2[k], "ADD UNIQUE (`@fieldname@`)") +            end + +            if options.AUTO then +                r[k] = r[k] .. " AUTO_INCREMENT" +            end +        end +        return r, r2 +    end, +     +    opentable = function(db, tablename, ddl) +        local fields, alters = _luadb.generate_fields(ddl) +        local tname = "`" .. db._.prefix .. db.sql_escape(tablename) .. "`" +         +        if db:SafeQuery("DESC " .. tname) ~= 0 then +            -- table doesn't exist, create it +            if db._.conn:ErrNO() == 1146 then +                local q = "CREATE TABLE " .. tname .. " (" +                local k, v, first +                first = true +                 +                for k, v in pairs(fields) do +                    if not first then +                        q = q .. ", " +                    end +                    q = q .. "`" .. db.sql_escape(k) .. "` " .. v +                end +                 +                q = q .. ");" +                 +                if db:SafeQuery(q) ~= 0 then +                    error("Error creating table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                end +                 +                for k, v in pairs(alters) do +                    local _, stmt +                    for _, stmt in pairs(v) do +                        q = "ALTER TABLE " .. tname .. " " .. string.gsub(stmt, "@fieldname@", db.sql_escape(k)) +                        if db:SafeQuery(q) ~= 0 then +                            error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                            db:SafeQuery("DROP TABLE " .. tname ";") +                        end +                    end +                end +            else +                error("Error getting description of table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error()) +            end +        else +            -- table exists, let's check it. +            local dfields = {} +            local i +            local any_common = false +             +            for i = 1, db._.conn:NumRows() do +                dfields[i] = db._.conn:FetchRow() +                if fields[dfields[i].Field] then +                    any_common = true +                end +            end +             +            if not any_common then +                -- strictly no field in common; drop the table and recursively call opentable in order to proceed with a createtable instead. +                q = "DROP TABLE " .. tname .. ";" +                if db:SafeQuery(q) ~= 0 then +                    error("Error dropping table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                end +                return db:opentable(tablename, ddl) +            end +             +            local d, k, v = {} +             +            for k, v in pairs(dfields) do +                d[v.Field] = v +                if not fields[v.Field] then +                    q = "ALTER TABLE " .. tname .. " DROP COLUMN `" .. db.sql_escape(v.Field) .. "`;" +                    if db:SafeQuery(q) ~= 0 then +                        error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                    end +                end +            end +            dfields, d = d, nil +             +            local deffered_alters = {} +             +            for k, v in pairs(fields) do +                q = nil +                if not dfields[k] then +                    q = "ALTER TABLE " .. tname .. " ADD `" .. db.sql_escape(k) .. "` " .. v .. ";" +                else +                    local identicals, alters = _luadb.compare_fields(ddl, dfields[k]) +                    if not identicals then +                        q = "ALTER TABLE " .. tname .. " MODIFY `" .. db.sql_escape(k) .. "` " .. v .. ";" +                        local _, a +                        for _, a in pairs(alters) do +                            a.stmt = string.gsub(a.stmt, "@fieldname@", db.sql_escape(k)) +                            if a.pri < 0 then +                                table.insert(deffered_alters, a) +                            else +                                q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" +                                if db:SafeQuery(q) ~= 0 then +                                    error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                                end +                            end +                        end +                    end +                end +                if q then +                    if db:SafeQuery(q) ~= 0 then +                        error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                    end +                end +            end +             +            table.sort(deffered_alters, function(a, b) return a.pri < b.pri end) +             +            for k, v in pairs(deffered_alters) do +                q = "ALTER TABLE " .. tname .. " " .. a.stmt .. ";" +                if db:SafeQuery(q) ~= 0 then +                    error("Error altering table " .. tname .. ": " .. db:ErrNO() .. " - " .. db:Error() .. " - query run = " .. q) +                end +            end +        end +         +        return { +            insert = _luadb.insert, +            delete = _luadb.delete, +            restrict = _luadb.restrict, +            restricter = _luadb.restricter, +            _ = { +                tablename = tablename, +                tname = tname, +                db = db, +                ddl = ddl, +                conditions = "1", +            }, +        } +    end, + +    compare_fields = function(ddl, field) +        local f = field.Field +        local identical = true +        local alters = {} +        if type(ddl[f]) ~= "table" then +            error("Internal issue with the DDL: key " .. f .. " isn't a table.") +        end +        -- check base type first. +        local ddltype = _luadb.get_canon_type(ddl[f].type, f) +        local desctype, desclength, descdecim = field.Type:match "(%w+)%((%d+)%.(%d+)%)" +        if desctype == nil then +            desctype, desclength = field.Type:match "(%w+)%((%d+)%)" +        end +        if desctype == nil then +            error("Error parsing type string from database: " .. f .. ": " .. field.Type) +        end +        desctype = _luadb.get_canon_type(desctype, f) +         +        if ddltype ~= desctype then +            identical = false +        end +         +        if field.length and field.length ~= desclength then +            identical = false +        end +         +        if field.decimal and field.decimal ~= descdecim then +            identical = false +        end + +        local options = ddl[f].options and _luadb.get_options(ddl[f].options) or {} +         +        if options.PRI and field.Key ~= "PRI" then +            -- needs to add a primary key here, with a specific ALTER TABLE statement +            identical = false +            table.insert(alters, { stmt = "ADD PRIMARY KEY(@fieldname@)", pri = -2 } ) +        end +         +        if not options.PRI and field.Key == "PRI" then +            -- needs to drop the primary key here. +            identical = false +            table.insert(alters, { stmt = "DROP PRIMARY KEY", pri = -1 } ) +        end +         +        if options.UNI and field.Key ~= "UNI" then +            -- needs to add a unique index +            identical = false +            table.insert(alters, { stmt = "ADD UNIQUE(@fieldname@)", pri = 0 } ) +        end +         +        if not options.UNI and field.Key == "UNI" then +            -- needs to drop the unique index +            identical = false +            table.insert(alters, { stmt = "DROP INDEX @fieldname@", pri = 0 } ) +        end +         +        if not options.AUTO and field.Extra == "auto_increment" then +            identical = false +        end +         +        if options.AUTO and field.Extra ~= "auto_increment" then +            identical = false +        end +         +        return identical, alters +    end, +     +    insert = function(t, data) +        local k, v, stmt, stmt2 +        local first = true +         +        stmt = "INSERT INTO " .. t._.tname .. " (" +        stmt2 = ") VALUES (" +         +        for k, v in pairs(data) do +            if not t._.ddl[k] then +                error("Trying to address a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) +            end +            if not first then +                stmt = stmt .. ", " +                stmt2 = stmt2 .. ", " +            else +                first = false +            end +            stmt = stmt .. "`" .. k .. "`" +            stmt2 = stmt2 .. "'" .. v .. "'" +        end +         +        if t._.db:SafeQuery(stmt .. stmt2 .. ");") ~= 0 then +            error("Error inserting a row inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        end +         +        return t._.db._.conn:InsertId() +    end, +     +    delete = function(t) +        local stmt = "DELETE FROM " .. t._.tname .. " WHERE " .. t._.conditions .. ";" +         +        if t._.db:SafeQuery(stmt) ~= 0 then +            error("Error deleting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        end         +    end, +     +    update = function(t, data) +        local stmt = "UPDATE " .. t._.tname .. " SET " +        local first = true +        local k, v +         +        for k, v in pairs(data) do +            if t._.ddl[k] then +                error("Trying to update a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) +            end +            if not first then +                stmt = stmt .. ", " +            else +                first = false +            end +            stmt = stmt .. "`" .. k .. "`=" +            if type(v) == "string" then +                stmt = stmt .. '"' .. t._.db.sql_escape(v) .. '"' +            else +                error("Complex UPDATE queries are not supported yet.") +            end +        end +         +        stmt = stmt .. " WHERE " .. t._.conditions +         +        if t._.db:SafeQuery(stmt) ~= 0 then +            error("Error updating row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        end         +    end, +     +    select = function(t, fields) +        local stmt = "SELECT " +        local first = true +        local k, v +         +        for k, v in pairs(fields) do +            if t._.ddl[v] then +                error("Trying to select a field which doesn't exist: " .. k .. " - inside of table " .. t._.tablename) +            end +            if not first then +                stmt = stmt .. ", " +            else +                first = false +            end +            stmt = stmt .. "`" .. v .. "`" +        end +         +        stmt = stmt .. " FROM " .. t._.tname .. " WHERE " .. t._.conditions +         +        if t._.db:SafeQuery(stmt) ~= 0 then +            error("Error selecting row(s) inside table " .. t._.tablename .. ": " .. t._.db:Error()) +        end +         +        return _luadb.nextrow, t +    end, +     +    nextrow = function(t) +        return t._.db._.conn:FetchRow() +    end, +     +    _restricts = { +        -- All the rendering is done on-the-fly. +        -- this is the leaf function; thus this is the only one which ough to get the sql_escape. +        dop = function(r, field, expr, op) +            if not r._.tbl._.ddl[field] then +                error("Trying to restrict on a field which doesn't exist: " .. field .. " - inside of table " .. r._.tbl._.tablename) +            end +             +            if type(expr) == "string" or type(expr) == "number" then +                return "`" .. field .. "` " .. op .. ' "' .. t._.db.sql_escape(expr) .. '"' +            elseif type(expr) == "table" then +                error("Complex queries not handled for now") +            else +                error("Wrong type for expression in operator: " .. type(expr)) +            end +        end, +         +        varops = function(r, op, ...) +            local exprs, first, r, k, v = {...}, true +             +            for k, v in pairs(exprs) do +                if not first then +                    r = r .. " " .. op .. " " +                end +                r = r .. v +                first = false +            end +        end, +    }, +     +    restricts = { +        NOT = function(r, expr) +            return "NOT " .. expr +        end, +         +        AND = function(r, ...) return _luadb._restricts.varops(r, "AND", ...) end, +        OR = function(r, ...) return _luadb._restricts.varops(r, "OR", ...) end, +         +        EQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "=") end, +        NEQ = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "!=") end, +        GT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">") end, +        LT = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<") end, +        GE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, ">=") end, +        LE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "<=") end, +        LIKE = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "LIKE") end, +        IN = function(r, field, expr) return _luadb._restricts.dop(r, field, expr, "IN") end, +    }, +     +    restricter = function(t) +        local r, k, v = { _ = { tbl = t } } +         +        for k, v in pairs(_luadb.restricts) do +            r[k] = v +        end +         +        return r +    end, +     +    restrict = function(t, conditions) +        -- all in all, "conditions" can just be a direct string. Let the programmer do what he wants, after all. +        if conditions == nil or conditions == "" then conditions = "1" end +        t._.conditions = conditions +    end, +} + +luadb = { +    SafeQuery = function(db, str) +        local r = db._.conn:Query(str) +        if luadb.debugmode then +            print(str) +        end +         +        if r ~= 0 and db._.conn:ErrNO() == 2006 then -- disconnected +            db._.conn = db:opendb() +            r = db._.conn:Query(str) +        end +         +        return r +    end, +     +    opendb = function(id, user, password, base, prefix) +        local db +         +--        if user == nil and password == nil and base == nil then +            -- get into a registry to fetch "id" +--        else +--        end +        if type(id) == "table" and type(id._) == "table" then +            user = id._.user +            password = id._.password +            base = id._.base +            prefix = id._.prefix +            id = id._.id +        end +        db = SQLConnection(id, user, password, base) +        return { +            opentable = _luadb.opentable, +            SafeQuery = luadb.SafeQuery, +            sql_escape = sql_escape, +            ErrNO = function(db) return db._.conn:ErrNO() end, +            Error = function(db) return db._.conn:Error() end, +            _ = { +                conn = db, +                id = id, +                user = user, +                password = password, +                base = base, +                prefix = sql_escape(prefix or ""), +            }, +        } +    end, +} | 
