From 111f23af9227e42d9cb27f59a396cb53ecd4e14c Mon Sep 17 00:00:00 2001 From: Alexander Kiranov Date: Sat, 25 Nov 2023 12:03:12 +0200 Subject: [PATCH] stable table keys iteration order by tools.sortedpairs helper --- mqtt/protocol.lua | 5 +++-- mqtt/protocol5.lua | 7 ++++-- mqtt/tools.lua | 41 ++++++++++++++++++++++++++++++++++++ tests/spec/module-basics.lua | 23 ++++++++++++++++++++ 4 files changed, 72 insertions(+), 4 deletions(-) diff --git a/mqtt/protocol.lua b/mqtt/protocol.lua index 040ecfc..1ab62fa 100644 --- a/mqtt/protocol.lua +++ b/mqtt/protocol.lua @@ -56,6 +56,7 @@ local rshift = bit.rshift local tools = require("mqtt.tools") local div = tools.div +local sortedpairs = tools.sortedpairs --- Create bytes of the uint8 value -- @tparam number val - integer value to convert to bytes @@ -489,7 +490,7 @@ local function value_tostring(value) return str_format("%q", value) elseif t == "table" then local res = {} - for k, v in pairs(value) do + for k, v in sortedpairs(value) do if type(k) == "number" then res[#res + 1] = value_tostring(v) else @@ -511,7 +512,7 @@ end -- @treturn string human-readable string representation of the packet function protocol.packet_tostring(packet) local res = {} - for k, v in pairs(packet) do + for k, v in sortedpairs(packet) do res[#res + 1] = str_format("%s=%s", k, value_tostring(v)) end return str_format("%s{%s}", tostring(packet_type[packet.type]), tbl_concat(res, ", ")) diff --git a/mqtt/protocol5.lua b/mqtt/protocol5.lua index fb1ce76..8510aa4 100644 --- a/mqtt/protocol5.lua +++ b/mqtt/protocol5.lua @@ -26,6 +26,9 @@ local string = require("string") local str_char = string.char local fmt = string.format +local tools = require("mqtt.tools") +local sortedpairs = tools.sortedpairs + local bit = require("mqtt.bitwrap") local bor = bit.bor local band = bit.band @@ -321,7 +324,7 @@ local function make_properties(ptype, args) assert(type(args.properties) == "table", "expecting .properties to be a table") -- validate all properties and append them to order list local order = {} - for name, value in pairs(args.properties) do + for name, value in sortedpairs(args.properties) do assert(type(name) == "string", "expecting property name to be a string: "..tostring(name)) -- detect property identifier and check it's allowed for that packet type local prop_id = assert(properties[name], "unknown property: "..tostring(name)) @@ -360,7 +363,7 @@ local function make_properties(ptype, args) assert(type(args.user_properties) == "table", "expecting .user_properties to be a table") assert(allowed[uprop_id], "user_property is not allowed for packet type "..ptype) local order = {} - for name, val in pairs(args.user_properties) do + for name, val in sortedpairs(args.user_properties) do local ntype = type(name) if ntype == "string" then if type(val) ~= "string" then diff --git a/mqtt/tools.lua b/mqtt/tools.lua index 48b8886..cba3a90 100644 --- a/mqtt/tools.lua +++ b/mqtt/tools.lua @@ -10,6 +10,11 @@ local str_byte = string.byte local table = require("table") local tbl_concat = table.concat +local tbl_sort = table.sort + +local type = type +local error = error +local pairs = pairs local math = require("math") local math_floor = math.floor @@ -29,6 +34,42 @@ function tools.div(x, y) return math_floor(x / y) end +-- table.sort callback for tools.sortedpairs() +local function sortedpairs_compare(a, b) + local a_type = type(a) + local b_type = type(b) + if (a_type == "string" and b_type == "string") or (a_type == "number" and b_type == "number") then + return a < b + elseif a_type == "number" then + return true + elseif b_type == "number" then + return false + else + error("sortedpairs failed to make a stable keys comparison of types "..a_type.." and "..b_type) + end +end + +-- Iterate through table keys and values in stable (sorted) order +function tools.sortedpairs(tbl) + local keys = {} + for k in pairs(tbl) do + local k_type = type(k) + if k_type ~= "string" and k_type ~= "number" then + error("sortedpairs failed to make a stable iteration order for key of type "..type(k)) + end + keys[#keys + 1] = k + end + tbl_sort(keys, sortedpairs_compare) + local i = 0 + return function() + i = i + 1 + local key = keys[i] + if key then + return key, tbl[key] + end + end +end + -- export module table return tools diff --git a/tests/spec/module-basics.lua b/tests/spec/module-basics.lua index 95777a7..0321fc8 100644 --- a/tests/spec/module-basics.lua +++ b/tests/spec/module-basics.lua @@ -34,6 +34,29 @@ describe("MQTT lua library component test:", function() assert.are.equal("FF00FF", tools.hex("\255\000\255")) end) + it("tools.sortedpairs", function() + assert.are.equal(type(tools.sortedpairs), "function") + + -- naive table-to-string implementation with stable table key iteration order + local function tbl_tostring(tbl) + local res = {"{"} + for k, v in tools.sortedpairs(tbl) do + res[#res + 1] = string.format("%s=%s,", k, v) + end + res[#res + 1] = "}" + return table.concat(res) + end + + assert.are.equal("{}", tbl_tostring{}) + assert.are.equal('{a=1,}', tbl_tostring{a=1,}) + assert.are.equal('{a=1,b=2,}', tbl_tostring{b=2,a=1,}) + assert.are.equal('{a=1,b=2,}', tbl_tostring{b=2,a=1,}) + assert.are.equal('{1=1,2=2,3=3,}', tbl_tostring{1,2,3,}) + assert.are.equal('{1=1,2=2,3=3,}', tbl_tostring{[3]=3,[2]=2,[1]=1,}) + assert.are.equal('{1=1,a=1,}', tbl_tostring{1,a=1,}) + assert.are.equal('{1=1,2=2,3=3,a=1,b=2,}', tbl_tostring{b=2,a=1,[3]=3,[2]=2,[1]=1,}) + end) + it("extract_hex", function() assert.are.equal("", extract_hex("")) assert.are.equal("", extract_hex(" "))