Skip to content

Commit 2356e30

Browse files
committed
Use multicast to wait for new scan results
1 parent 2f74341 commit 2356e30

File tree

2 files changed

+101
-5
lines changed

2 files changed

+101
-5
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ netlink-sys = { version = "0.8.4" }
3838

3939
[dev-dependencies]
4040
env_logger = "0.9.0"
41+
anyhow = "1.0.100"
4142

4243
[dev-dependencies.tokio]
4344
version = "1.11.0"

examples/nl80211_trigger_scan.rs

Lines changed: 100 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,22 @@
22

33
use std::env::args;
44

5-
use futures::stream::TryStreamExt;
5+
use anyhow::{anyhow, Result};
6+
use futures::{stream::TryStreamExt, StreamExt};
7+
use netlink_packet_core::ParseableParametrized;
68
use netlink_packet_core::{DecodeError, ErrorContext};
9+
use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST};
10+
use netlink_packet_generic::{
11+
ctrl::{
12+
nlas::{GenlCtrlAttrs, McastGrpAttrs},
13+
GenlCtrl, GenlCtrlCmd,
14+
},
15+
GenlMessage,
16+
};
17+
use netlink_sys::AsyncSocket;
18+
use wl_nl80211::{Nl80211Attr, Nl80211Command, Nl80211Message};
719

8-
fn main() -> Result<(), Box<dyn std::error::Error>> {
20+
fn main() -> Result<()> {
921
let argv: Vec<_> = args().collect();
1022

1123
if argv.len() < 2 {
@@ -29,8 +41,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
2941
Ok(())
3042
}
3143

32-
async fn dump_scan(if_index: u32) -> Result<(), Box<dyn std::error::Error>> {
33-
let (connection, handle, _) = wl_nl80211::new_connection()?;
44+
async fn dump_scan(if_index: u32) -> Result<()> {
45+
let (mut connection, handle, mut messages) = wl_nl80211::new_connection()?;
46+
47+
// Attach the connection socket to the multicast scan group to find out,
48+
// when the scan is finished.
49+
let socket = connection.socket_mut().socket_mut();
50+
socket.bind_auto()?;
51+
socket.add_membership(get_scan_multicast_id().await?)?;
52+
3453
tokio::spawn(connection);
3554

3655
let duration = 5000;
@@ -45,7 +64,23 @@ async fn dump_scan(if_index: u32) -> Result<(), Box<dyn std::error::Error>> {
4564
while let Some(msg) = scan_handle.try_next().await? {
4665
msgs.push(msg);
4766
}
48-
tokio::time::sleep(std::time::Duration::from_millis(duration.into())).await;
67+
68+
while let Some((message, _)) = messages.next().await {
69+
match message.payload {
70+
NetlinkPayload::InnerMessage(msg) => {
71+
let msg = Nl80211Message::parse_with_param(
72+
msg.payload.as_slice(),
73+
msg.header,
74+
)?;
75+
if msg.cmd == Nl80211Command::NewScanResults
76+
&& msg.attributes.contains(&Nl80211Attr::IfIndex(if_index))
77+
{
78+
break;
79+
}
80+
}
81+
_ => continue,
82+
}
83+
}
4984

5085
let mut dump = handle.scan().dump(if_index).execute().await;
5186
let mut msgs = Vec::new();
@@ -56,5 +91,65 @@ async fn dump_scan(if_index: u32) -> Result<(), Box<dyn std::error::Error>> {
5691
for msg in msgs {
5792
println!("{msg:?}");
5893
}
94+
5995
Ok(())
6096
}
97+
98+
async fn get_scan_multicast_id() -> Result<u32> {
99+
let (conn, mut handle, _) = wl_nl80211::new_connection()?;
100+
tokio::spawn(conn);
101+
102+
let mut nl_msg =
103+
NetlinkMessage::from(GenlMessage::from_payload(GenlCtrl {
104+
cmd: GenlCtrlCmd::GetFamily,
105+
nlas: vec![GenlCtrlAttrs::FamilyName("nl80211".to_owned())],
106+
}));
107+
108+
// To get the mcast groups for the nl80211 family, we must also set the
109+
// message type id
110+
nl_msg.header.message_type =
111+
handle.handle.resolve_family_id::<Nl80211Message>().await?;
112+
// This is a request, but not a dump. Which means, the family name has to be
113+
// specified, to obtain it's information.
114+
nl_msg.header.flags = NLM_F_REQUEST;
115+
116+
let responses = handle.handle.request(nl_msg).await?;
117+
let nl80211_family: Vec<Vec<GenlCtrlAttrs>> = responses
118+
.try_filter_map(|msg| async move {
119+
match msg.payload {
120+
NetlinkPayload::InnerMessage(genlmsg)
121+
if genlmsg.payload.cmd == GenlCtrlCmd::NewFamily
122+
&& genlmsg.payload.nlas.contains(
123+
&GenlCtrlAttrs::FamilyName("nl80211".to_owned()),
124+
) =>
125+
{
126+
Ok(Some(genlmsg.payload.nlas.clone()))
127+
}
128+
_ => Ok(None),
129+
}
130+
})
131+
.try_collect()
132+
.await?;
133+
134+
// Now get the mcid for "nl80211" "scan" group
135+
let scan_multicast_id = nl80211_family
136+
.first()
137+
.ok_or_else(|| anyhow!("Missing \"nl80211\" family"))?
138+
.iter()
139+
.find_map(|attr| match attr {
140+
GenlCtrlAttrs::McastGroups(mcast_groups) => Some(mcast_groups),
141+
_ => None,
142+
})
143+
.ok_or_else(|| anyhow!("Missing McastGroup attribute"))?
144+
.iter()
145+
.find(|grp| grp.contains(&McastGrpAttrs::Name("scan".to_owned())))
146+
.ok_or_else(|| anyhow!("Missing scan group"))?
147+
.iter()
148+
.find_map(|grp_attr| match grp_attr {
149+
McastGrpAttrs::Id(id) => Some(*id),
150+
_ => None,
151+
})
152+
.ok_or_else(|| anyhow!("No multicast id defined for scan group"))?;
153+
154+
Ok(scan_multicast_id)
155+
}

0 commit comments

Comments
 (0)