diff --git a/resources/install/scripts/resources/functions/database.lua b/resources/install/scripts/resources/functions/database.lua index 35c4e1c87f..4d2791c09e 100644 --- a/resources/install/scripts/resources/functions/database.lua +++ b/resources/install/scripts/resources/functions/database.lua @@ -19,6 +19,52 @@ BACKEND.main = BACKEND.main or 'native' local unpack = unpack or table.unpack +----------------------------------------------------------- + +local NULL, DEFAULT = {}, {} + +local param_pattern = "[:]([^%d%s][%a%d_]+)" + +-- +-- Substitude named parameters to query +-- +-- @tparam string sql query text +-- @tparam table params values for parameters +-- @treturn[1] string new sql query +-- @treturn[2] nil +-- @treturn[2] string error message +-- +local function apply_params(db, sql, params) + params = params or {} + + local str = string.gsub(sql, param_pattern, function(param) + local v, t = params[param], type(params[param]) + if "string" == t then return db:quote(v) end + if "number" == t then return tostring(v) end + if "boolean" == t then return v and '1' or '0' end + if NULL == v then return 'NULL' end + if DEFAULT == v then return 'DEFAULT' end + err = "undefined parameter: " .. param + end) + + if err then return nil, err end + + return str +end + +local sql_escape + +if freeswitch then + local api = require "resources.functions.api" + sql_escape = function(str) + return api:execute('sql_escape', str) + end +else + sql_escape = function(str) + return (string.gsub(str, "'", "''")) + end +end + ----------------------------------------------------------- local installed_classes = {} local default_backend = FsDatabase @@ -31,6 +77,10 @@ local function new_database(backend, backend_name) Database.__base = backend or default_backend Database = setmetatable(Database, Database.__base) + Database.NULL = NULL + + Database.DEFAULT = NULL + function Database.new(...) local self = Database.__base.new(...) setmetatable(self, Database) @@ -40,10 +90,52 @@ local function new_database(backend, backend_name) function Database:backend_name() return backend_name end - - function Database:first_row(sql) + + function Database:_apply_params(sql, params) + return apply_params(self, sql, params) + end + + function Database:query(sql, ...) + local params, callback + + local argc = select('#', ...) + + if argc > 0 then + local p = select(argc, ...) + if (p == nil) or (type(p) == 'function') then + callback = p + argc = argc - 1 + end + end + + if argc > 0 then + local p = select(argc, ...) + if (p == nil) or (type(p) == 'table') then + params = p + argc = argc - 1 + end + end + + assert(argc == 0, 'invalid argument #' .. tostring(argc)) + + if params then + -- backend supports parameters natively + if self.__base.parameter_query then + return self.__base.parameter_query(self, sql, params, callback) + end + + -- use emulation of parametes + local err + sql, err = self:_apply_params(sql, params) + if not sql then return nil, err end + end + + return self.__base.query(self, sql, callback) + end + + function Database:first_row(sql, params) local result - local ok, err = self:query(sql, function(row) + local ok, err = self:query(sql, params, function(row) result = row return 1 end) @@ -51,21 +143,33 @@ local function new_database(backend, backend_name) return result end - function Database:first_value(sql) - local result, err = self:first_row(sql) + function Database:first_value(sql, params) + local result, err = self:first_row(sql, params) if not result then return nil, err end local k, v = next(result) return v end function Database:first(sql, ...) - local result, err = self:first_row(sql) - if not result then return nil, err end - local t, n = {}, select('#', ...) - for i = 1, n do - t[i] = result[(select(i, ...))] + local t = type((...)) + local has_params = (t == 'nil') or (t == 'table') + + local result, err + if has_params then + result, err = self:first_row(sql, (...)) + else + result, err = self:first_row(sql) end - return unpack(t, 1, n) + + if not result then return nil, err end + + local t, n, c = {}, select('#', ...), 0 + for i = (has_params and 2 or 1), n do + c = c + 1 + t[c] = result[(select(i, ...))] + end + + return unpack(t, 1, c) end function Database:fetch_all(sql) @@ -78,7 +182,7 @@ local function new_database(backend, backend_name) end function Database:escape(str) - return (string.gsub(str, "'", "''")) + return sql_escape(str) end function Database:quote(str) @@ -176,6 +280,46 @@ local function new_database(backend, backend_name) db:release() assert(not db:connected()) + local db = Database.new(...) + + assert(db:connected()) + + -- test substitude parameters + t = assert(db:first_row('select :p1 as p1, :p2 as p2', {p1 = 'hello', p2 = 'world'})) + assert(t.p1 == 'hello') + assert(t.p2 == 'world') + + -- test escape string + -- `sql_escape` on freeswitch do `trim` + if not freeswitch then + -- test no trim value + local v = " hello " + a = assert(db:first_value('select :p1', {p1 = v})) + assert(a == v) + + -- test newline + -- On Windows with pgsql it replace `\n` to `\r\n`) + local v = "\r\nhello\r\nworld\r\n" + a = assert(db:first_value('select :p1', {p1 = v})) + assert(a == v, string.format('%q', tostring(a))) + end + + -- test backslash + local v = "\\hello\\world\\" + a = assert(db:first_value('select :p1', {p1 = v})) + assert(a == v, string.format('%q', tostring(a))) + + -- test single quote + local v = "'hello''world'''" + a = assert(db:first_value('select :p1', {p1 = v})) + assert(a == v, string.format('%q', tostring(a))) + + -- test empty string + local v = "" + a = assert(db:first_value('select :p1', {p1 = v})) + assert(a == v, string.format('%q', tostring(a))) + + db:release() log.info('self_test Database - pass') end