Skip to content

Commit

Permalink
stable table keys iteration order by tools.sortedpairs helper
Browse files Browse the repository at this point in the history
  • Loading branch information
xHasKx committed Nov 25, 2023
1 parent e3ae985 commit 111f23a
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 4 deletions.
5 changes: 3 additions & 2 deletions mqtt/protocol.lua
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ local rshift = bit.rshift

local tools = require("mqtt.tools")
local div = tools.div
local sortedpairs = tools.sortedpairs

--- Create bytes of the uint8 value
-- @tparam number val - integer value to convert to bytes
Expand Down Expand Up @@ -489,7 +490,7 @@ local function value_tostring(value)
return str_format("%q", value)
elseif t == "table" then
local res = {}
for k, v in pairs(value) do
for k, v in sortedpairs(value) do
if type(k) == "number" then
res[#res + 1] = value_tostring(v)
else
Expand All @@ -511,7 +512,7 @@ end
-- @treturn string human-readable string representation of the packet
function protocol.packet_tostring(packet)
local res = {}
for k, v in pairs(packet) do
for k, v in sortedpairs(packet) do
res[#res + 1] = str_format("%s=%s", k, value_tostring(v))
end
return str_format("%s{%s}", tostring(packet_type[packet.type]), tbl_concat(res, ", "))
Expand Down
7 changes: 5 additions & 2 deletions mqtt/protocol5.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ local string = require("string")
local str_char = string.char
local fmt = string.format

local tools = require("mqtt.tools")
local sortedpairs = tools.sortedpairs

local bit = require("mqtt.bitwrap")
local bor = bit.bor
local band = bit.band
Expand Down Expand Up @@ -321,7 +324,7 @@ local function make_properties(ptype, args)
assert(type(args.properties) == "table", "expecting .properties to be a table")
-- validate all properties and append them to order list
local order = {}
for name, value in pairs(args.properties) do
for name, value in sortedpairs(args.properties) do
assert(type(name) == "string", "expecting property name to be a string: "..tostring(name))
-- detect property identifier and check it's allowed for that packet type
local prop_id = assert(properties[name], "unknown property: "..tostring(name))
Expand Down Expand Up @@ -360,7 +363,7 @@ local function make_properties(ptype, args)
assert(type(args.user_properties) == "table", "expecting .user_properties to be a table")
assert(allowed[uprop_id], "user_property is not allowed for packet type "..ptype)
local order = {}
for name, val in pairs(args.user_properties) do
for name, val in sortedpairs(args.user_properties) do
local ntype = type(name)
if ntype == "string" then
if type(val) ~= "string" then
Expand Down
41 changes: 41 additions & 0 deletions mqtt/tools.lua
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ local str_byte = string.byte

local table = require("table")
local tbl_concat = table.concat
local tbl_sort = table.sort

local type = type
local error = error
local pairs = pairs

local math = require("math")
local math_floor = math.floor
Expand All @@ -29,6 +34,42 @@ function tools.div(x, y)
return math_floor(x / y)
end

-- table.sort callback for tools.sortedpairs()
local function sortedpairs_compare(a, b)
local a_type = type(a)
local b_type = type(b)
if (a_type == "string" and b_type == "string") or (a_type == "number" and b_type == "number") then
return a < b
elseif a_type == "number" then
return true
elseif b_type == "number" then
return false
else
error("sortedpairs failed to make a stable keys comparison of types "..a_type.." and "..b_type)
end
end

-- Iterate through table keys and values in stable (sorted) order
function tools.sortedpairs(tbl)
local keys = {}
for k in pairs(tbl) do
local k_type = type(k)
if k_type ~= "string" and k_type ~= "number" then
error("sortedpairs failed to make a stable iteration order for key of type "..type(k))
end
keys[#keys + 1] = k
end
tbl_sort(keys, sortedpairs_compare)
local i = 0
return function()
i = i + 1
local key = keys[i]
if key then
return key, tbl[key]
end
end
end

-- export module table
return tools

Expand Down
23 changes: 23 additions & 0 deletions tests/spec/module-basics.lua
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@ describe("MQTT lua library component test:", function()
assert.are.equal("FF00FF", tools.hex("\255\000\255"))
end)

it("tools.sortedpairs", function()
assert.are.equal(type(tools.sortedpairs), "function")

-- naive table-to-string implementation with stable table key iteration order
local function tbl_tostring(tbl)
local res = {"{"}
for k, v in tools.sortedpairs(tbl) do
res[#res + 1] = string.format("%s=%s,", k, v)
end
res[#res + 1] = "}"
return table.concat(res)
end

assert.are.equal("{}", tbl_tostring{})
assert.are.equal('{a=1,}', tbl_tostring{a=1,})
assert.are.equal('{a=1,b=2,}', tbl_tostring{b=2,a=1,})
assert.are.equal('{a=1,b=2,}', tbl_tostring{b=2,a=1,})
assert.are.equal('{1=1,2=2,3=3,}', tbl_tostring{1,2,3,})
assert.are.equal('{1=1,2=2,3=3,}', tbl_tostring{[3]=3,[2]=2,[1]=1,})
assert.are.equal('{1=1,a=1,}', tbl_tostring{1,a=1,})
assert.are.equal('{1=1,2=2,3=3,a=1,b=2,}', tbl_tostring{b=2,a=1,[3]=3,[2]=2,[1]=1,})
end)

it("extract_hex", function()
assert.are.equal("", extract_hex(""))
assert.are.equal("", extract_hex(" "))
Expand Down

0 comments on commit 111f23a

Please sign in to comment.