Skip to content

Commit a92e93a

Browse files
committed
Fix done message parse
Though the document of kernel (https://kernel.org/doc/html/next/userspace-api/netlink/intro.html#netlink-message-types) specify the format of netlink message with NLMSG_DONE, it also says that "Note that some implementations may issue custom NLMSG_DONE messages in reply to do action requests. In that case the payload is implementation-specific and may also be absent.". After searching the source code of kernel, we can find that 1. the format specified in the document is obeyed by most generic netlink but some generic netlink like this (https://elixir.bootlin.com/linux/v6.15/source/drivers/net/team/team_core.c#L2494) has no payload, so as a generic lib, we should not suppose the format of DoneMessage. it's sensible to just save the payload. 2. when use NLMSG_DONE as an end of multi messages, there will always be a NLM_F_MULTIPART in the flag and only in this case should we parse it as a DoneMessage, in other occassion like connector netlink (https://elixir.bootlin.com/linux/v6.15/source/drivers/connector/connector.c#L101), we should parse it as a common message. Signed-off-by: feng zhao <[email protected]>
1 parent 01e8dd1 commit a92e93a

File tree

2 files changed

+44
-102
lines changed

2 files changed

+44
-102
lines changed

src/done.rs

Lines changed: 32 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
// SPDX-License-Identifier: MIT
22

3-
use std::mem::size_of;
4-
53
use byteorder::{ByteOrder, NativeEndian};
64
use netlink_packet_utils::DecodeError;
75

@@ -11,99 +9,52 @@ const CODE: Field = 0..4;
119
const EXTENDED_ACK: Rest = 4..;
1210
const DONE_HEADER_LEN: usize = EXTENDED_ACK.start;
1311

14-
#[derive(Debug, PartialEq, Eq, Clone)]
12+
#[derive(Debug, Default, Clone, PartialEq, Eq)]
1513
#[non_exhaustive]
16-
pub struct DoneBuffer<T> {
17-
buffer: T,
14+
pub struct DoneMessage {
15+
pub payload: Vec<u8>,
1816
}
1917

20-
impl<T: AsRef<[u8]>> DoneBuffer<T> {
21-
pub fn new(buffer: T) -> DoneBuffer<T> {
22-
DoneBuffer { buffer }
23-
}
24-
25-
/// Consume the packet, returning the underlying buffer.
26-
pub fn into_inner(self) -> T {
27-
self.buffer
28-
}
29-
30-
pub fn new_checked(buffer: T) -> Result<Self, DecodeError> {
31-
let packet = Self::new(buffer);
32-
packet.check_buffer_length()?;
33-
Ok(packet)
34-
}
35-
36-
fn check_buffer_length(&self) -> Result<(), DecodeError> {
37-
let len = self.buffer.as_ref().len();
38-
if len < DONE_HEADER_LEN {
39-
Err(format!(
40-
"invalid DoneBuffer: length is {len} but DoneBuffer are \
41-
at least {DONE_HEADER_LEN} bytes"
42-
)
43-
.into())
18+
impl DoneMessage {
19+
pub fn code(&self) -> Option<i32> {
20+
if self.payload.len() < DONE_HEADER_LEN {
21+
None
4422
} else {
45-
Ok(())
23+
Some(NativeEndian::read_i32(&self.payload[CODE]))
4624
}
4725
}
4826

49-
/// Return the error code
50-
pub fn code(&self) -> i32 {
51-
let data = self.buffer.as_ref();
52-
NativeEndian::read_i32(&data[CODE])
53-
}
54-
}
55-
56-
impl<'a, T: AsRef<[u8]> + ?Sized> DoneBuffer<&'a T> {
57-
/// Return a pointer to the extended ack attributes.
58-
pub fn extended_ack(&self) -> &'a [u8] {
59-
let data = self.buffer.as_ref();
60-
&data[EXTENDED_ACK]
61-
}
62-
}
63-
64-
impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> DoneBuffer<&'a mut T> {
65-
/// Return a mutable pointer to the extended ack attributes.
66-
pub fn extended_ack_mut(&mut self) -> &mut [u8] {
67-
let data = self.buffer.as_mut();
68-
&mut data[EXTENDED_ACK]
27+
pub fn extended_ack(&self) -> Option<&[u8]> {
28+
if self.payload.len() < DONE_HEADER_LEN {
29+
None
30+
} else {
31+
Some(&self.payload[EXTENDED_ACK])
32+
}
6933
}
70-
}
7134

72-
impl<T: AsRef<[u8]> + AsMut<[u8]>> DoneBuffer<T> {
73-
/// set the error code field
74-
pub fn set_code(&mut self, value: i32) {
75-
let data = self.buffer.as_mut();
76-
NativeEndian::write_i32(&mut data[CODE], value)
35+
pub fn new_with_code<T: AsRef<[u8]>>(code: i32, extend_ack: &T) -> Self {
36+
let mut payload = vec![0; DONE_HEADER_LEN + extend_ack.as_ref().len()];
37+
NativeEndian::write_i32(&mut payload, code);
38+
payload[CODE.end..].copy_from_slice(extend_ack.as_ref());
39+
Self { payload }
7740
}
7841
}
7942

80-
#[derive(Debug, Default, Clone, PartialEq, Eq)]
81-
#[non_exhaustive]
82-
pub struct DoneMessage {
83-
pub code: i32,
84-
pub extended_ack: Vec<u8>,
85-
}
86-
8743
impl Emitable for DoneMessage {
8844
fn buffer_len(&self) -> usize {
89-
size_of::<i32>() + self.extended_ack.len()
45+
self.payload.len()
9046
}
9147
fn emit(&self, buffer: &mut [u8]) {
92-
let mut buffer = DoneBuffer::new(buffer);
93-
buffer.set_code(self.code);
94-
buffer
95-
.extended_ack_mut()
96-
.copy_from_slice(&self.extended_ack);
48+
buffer.copy_from_slice(&self.payload);
9749
}
9850
}
9951

100-
impl<T: AsRef<[u8]>> Parseable<DoneBuffer<&T>> for DoneMessage {
52+
impl<T: AsRef<[u8]>> Parseable<T> for DoneMessage {
10153
type Error = DecodeError;
10254

103-
fn parse(buf: &DoneBuffer<&T>) -> Result<DoneMessage, Self::Error> {
55+
fn parse(buf: &T) -> Result<DoneMessage, Self::Error> {
10456
Ok(DoneMessage {
105-
code: buf.code(),
106-
extended_ack: buf.extended_ack().to_vec(),
57+
payload: buf.as_ref().to_vec(),
10758
})
10859
}
10960
}
@@ -114,22 +65,18 @@ mod tests {
11465

11566
#[test]
11667
fn serialize_and_parse() {
117-
let expected = DoneMessage {
118-
code: 5,
119-
extended_ack: vec![1, 2, 3],
120-
};
121-
68+
let expected = DoneMessage::new_with_code(5, &[1, 2, 3]);
12269
let len = expected.buffer_len();
123-
assert_eq!(len, size_of::<i32>() + expected.extended_ack.len());
70+
assert_eq!(
71+
len,
72+
size_of::<i32>() + expected.extended_ack().unwrap().len()
73+
);
12474

12575
let mut buf = vec![0; len];
12676
expected.emit(&mut buf);
12777

128-
let done_buf = DoneBuffer::new(&buf);
129-
assert_eq!(done_buf.code(), expected.code);
130-
assert_eq!(done_buf.extended_ack(), &expected.extended_ack);
131-
132-
let got = DoneMessage::parse(&done_buf).unwrap();
133-
assert_eq!(got, expected);
78+
let got = DoneMessage::parse(&buf);
79+
assert!(got.is_ok());
80+
assert_eq!(got.unwrap(), expected);
13481
}
13582
}

src/message.rs

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ use netlink_packet_utils::DecodeError;
66

77
use crate::{
88
payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
9-
DoneBuffer, DoneMessage, Emitable, ErrorBuffer, ErrorMessage,
10-
NetlinkBuffer, NetlinkDeserializable, NetlinkHeader, NetlinkPayload,
11-
NetlinkSerializable, Parseable,
9+
DoneMessage, Emitable, ErrorBuffer, ErrorMessage, NetlinkBuffer,
10+
NetlinkDeserializable, NetlinkHeader, NetlinkPayload, NetlinkSerializable,
11+
Parseable, NLM_F_MULTIPART,
1212
};
1313

1414
/// Represent a netlink message.
@@ -103,10 +103,11 @@ where
103103
Error(msg)
104104
}
105105
NLMSG_NOOP => Noop,
106-
NLMSG_DONE => {
107-
let msg = DoneBuffer::new_checked(&bytes)
108-
.and_then(|buf| DoneMessage::parse(&buf))?;
109-
Done(msg)
106+
// only parse message_type of NLMSG_DONE when flag has
107+
// NLM_F_MULTIPART because some special netlink like
108+
// connector use NLMSG_DONE for all the message
109+
NLMSG_DONE if header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART => {
110+
Done(DoneMessage::parse(&bytes)?)
110111
}
111112
NLMSG_OVERRUN => Overrun(bytes.to_vec()),
112113
message_type => match I::deserialize(&header, bytes) {
@@ -205,11 +206,9 @@ mod tests {
205206

206207
#[test]
207208
fn test_done() {
208-
let header = NetlinkHeader::default();
209-
let done_msg = DoneMessage {
210-
code: 0,
211-
extended_ack: vec![6, 7, 8, 9],
212-
};
209+
let mut header = NetlinkHeader::default();
210+
header.flags |= NLM_F_MULTIPART;
211+
let done_msg = DoneMessage::new_with_code(0, &[6, 7, 8, 9]);
213212
let mut want = NetlinkMessage::new(
214213
header,
215214
NetlinkPayload::<FakeNetlinkInnerMessage>::Done(done_msg.clone()),
@@ -221,16 +220,12 @@ mod tests {
221220
len,
222221
header.buffer_len()
223222
+ size_of::<i32>()
224-
+ done_msg.extended_ack.len()
223+
+ done_msg.extended_ack().unwrap().len()
225224
);
226225

227226
let mut buf = vec![1; len];
228227
want.emit(&mut buf);
229228

230-
let done_buf = DoneBuffer::new(&buf[header.buffer_len()..]);
231-
assert_eq!(done_buf.code(), done_msg.code);
232-
assert_eq!(done_buf.extended_ack(), &done_msg.extended_ack);
233-
234229
let got = NetlinkMessage::parse(&NetlinkBuffer::new(&buf)).unwrap();
235230
assert_eq!(got, want);
236231
}

0 commit comments

Comments
 (0)