Skip to content

Commit

Permalink
Add control of parallel scanning processes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lab Owner committed Jun 19, 2023
1 parent 7e34221 commit 2a79de8
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 20 deletions.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2022 Joel Pothering
Copyright (c) 2023 Joel Pothering

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ Works fine generally, basically like the existing HA unifi_direct, and allows me

History
-
### v0.0.7
- Added device_tracker.py option ```--processes``` to control the number of concurrent processes run during AP scanning, or to run them sequentially (i.e. ```--processes=0```).
- Tested under Python11 and Debian12.
-
### v0.0.6
- Update README and docstrings corresponding to changes in HA 2022.9.
### v0.0.5
Expand Down
6 changes: 5 additions & 1 deletion app/device_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
import paho.mqtt.client as mqtt
import paho.mqtt.subscribe as subscribe
import ssl
from multiprocessing import Queue
from multiprocessing import Process
import unifi_tracker as unifi
Expand All @@ -30,6 +29,7 @@
UseHostKeysFile = False
SshTimeout = None
MaxIdleTime = None
Processes = None

Log = logging.getLogger(Logger_name)
AP_hosts = []
Expand Down Expand Up @@ -130,6 +130,8 @@ def process(last_clients):
unifiTracker.SshTimeout = SshTimeout
if MaxIdleTime is not None:
unifiTracker.MaxIdleTime = MaxIdleTime
if Processes is not None:
unifiTracker.Processes = Processes
for i in range(Snapshot_loop_count):
try:
last_clients, added, deleted = unifiTracker.scan_aps(ssh_username=Unifi_ssh_username,
Expand Down Expand Up @@ -173,6 +175,7 @@ def main():
ap.add_argument("--usehostkeys", required=False, action='store_true', default=UseHostKeysFile, help="Use known_hosts file.")
ap.add_argument("--sshTimeout", type=float, required=False, action='store', default=SshTimeout, help="SSH timeout in secs.")
ap.add_argument("--maxIdleTime", type=int, required=False, action='store', default=MaxIdleTime, help="Maximum AP client idle time in secs.")
ap.add_argument("--processes", type=int, required=False, action='store', default=Processes, help="Scans run in parallel; set to 0 for sequential.")
ap.add_argument("--mqtthost", type=str, required=False, action='store', default=Mqtt_host, help="MQTT host.")
ap.add_argument("--mqttport", type=int, required=False, action='store', default=Mqtt_port, help="MQTT port.")
ap.add_argument("--mqtts", required=False, action='store_true', default=False, help="Use MQTT TLS.")
Expand All @@ -195,6 +198,7 @@ def main():
UseHostKeysFile = args.usehostkeys
SshTimeout = args.sshTimeout
MaxIdleTime = args.maxIdleTime
Processes = args.processes
Log.debug(AP_hosts)
Mqtt_host = args.mqtthost
Mqtt_port = args.mqttport
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = unifi_tracker
version = 0.0.6
version = 0.0.7
author = Joel P.
author_email = [email protected]
description = Track the comings and goings of WiFi clients on multiple Unifi APs and generate a diff between scans.
Expand Down
2 changes: 1 addition & 1 deletion src/unifi_tracker/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
'''Track the comings and goings of WiFi clients on multiple Unifi APs and generate a diff between scans.'''

__version__ = '0.0.6'
__version__ = '0.0.7'

from .unifi_tracker import *
58 changes: 44 additions & 14 deletions src/unifi_tracker/unifi_tracker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import json
import logging
from paramiko import WarningPolicy
Expand Down Expand Up @@ -29,6 +30,8 @@ def __init__(self, useHostKeys: bool=False):
self.UNIFI_CLIENT_TABLE = 'sta_table'
# Some reasonable limit to number of hosts to scan in parallel.
self.MAX_AP_HOST_SCANS = 32
# Scanning processes run in parallel; set to 0 to run serially.
self._processes = os.cpu_count()

@property
def UseHostKeys(self):
Expand Down Expand Up @@ -57,6 +60,15 @@ def MaxIdleTime(self):
def MaxIdleTime(self, value: int):
self._maxIdleTime = value

@property
def Processes(self):
'''Scanning processes run in parallel; set to 0 for sequential processing.'''
return self._processes

@Processes.setter
def Processes(self, value: int):
self._processes = value

def exec_ssh_cmdline(self, user: str, host: str, cmdline: str):
'''Remotely execute command via SSH'''
try:
Expand Down Expand Up @@ -93,8 +105,23 @@ def get_ap_clients(self, ssh_username: str, ap_host: str):

def get_ap_mac_clients(self, ssh_username: str, ap_host: str):
'''MAC to client JSON from a Unifi AP'''
_LOGGER.debug(f'Scanning {ap_host}.')
return {client.get('mac').upper(): client for client in self.get_ap_clients(ssh_username=ssh_username, ap_host=ap_host)}

def parallel_scan(self, ssh_username: str, ap_hosts: list[str]):
'''List of results of parallel calls to get_ap_mac_clients.'''
with Pool(processes=self._processes) as pool:
_LOGGER.debug(f'Running {self._processes} scans in parallel.')
return pool.starmap(self.get_ap_mac_clients, [(ssh_username, ap_host) for ap_host in ap_hosts])

def sequential_scan(self, ssh_username: str, ap_hosts: list[str]):
'''List of results of sequential calls to get_ap_mac_clients'''
all_ap_mac_clients = []
for ap_host in ap_hosts:
macs_res = self.get_ap_mac_clients(ssh_username, ap_host)
all_ap_mac_clients.append(macs_res)
return all_ap_mac_clients

def scan_aps(self, ssh_username: str, ap_hosts: list[str], last_mac_clients: dict={}):
'''Retrieve and merge clients from all APs; diff with last retrieved.
Return tuple: dict of clients, list of client adds, list of client deletes.
Expand All @@ -106,20 +133,23 @@ def scan_aps(self, ssh_username: str, ap_hosts: list[str], last_mac_clients: dic
mac_clients = {}
added = []
deleted = []
with Pool() as pool:
for ap_mac_clients in pool.starmap(self.get_ap_mac_clients, [(ssh_username, ap_host) for ap_host in ap_hosts]):
if self._maxIdleTime is None:
mac_clients.update(ap_mac_clients)
else:
# Filter on clients below idle time threshold
for mac, client in ap_mac_clients.items():
idletime = client["idletime"] if 'idletime' in client else 0
_LOGGER.debug(f'{mac} idletime={idletime}')
if idletime > self._maxIdleTime:
if mac in last_mac_clients:
_LOGGER.info(f'{mac} exceeded idle time threshold; excluding.')
continue
mac_clients[mac] = client
if self._processes == 0:
all_ap_mac_clients = self.sequential_scan(ssh_username, ap_hosts)
else:
all_ap_mac_clients = self.parallel_scan(ssh_username, ap_hosts)
for ap_mac_clients in all_ap_mac_clients:
if self._maxIdleTime is None:
mac_clients.update(ap_mac_clients)
else:
# Filter on clients below idle time threshold
for mac, client in ap_mac_clients.items():
idletime = client["idletime"] if 'idletime' in client else 0
_LOGGER.debug(f'{mac} idletime={idletime}')
if idletime > self._maxIdleTime:
if mac in last_mac_clients:
_LOGGER.info(f'{mac} exceeded idle time threshold; excluding.')
continue
mac_clients[mac] = client
for mac, client in mac_clients.items():
if mac not in last_mac_clients:
added.append(mac)
Expand Down
4 changes: 2 additions & 2 deletions tests/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@

cp ../src/unifi_tracker/unifi_tracker.py .

python test_diff.py
python test_property_setters.py
python3 test_diff.py
python3 test_property_setters.py
13 changes: 13 additions & 0 deletions tests/test_property_setters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import unittest
import unifi_tracker as unifi

Expand Down Expand Up @@ -45,5 +46,17 @@ def test_maxIdleTime_setter(self):
unifi_tracker.MaxIdleTime = maxIdleTime + 1
assert(maxIdleTime + 1 == unifi_tracker.MaxIdleTime)

def test_processes_default(self):
# Processes defaults to os.cpu_count()
processes = os.cpu_count()
unifi_tracker = unifi.UnifiTracker()
assert(processes == unifi_tracker.Processes)

def test_processes_setter(self):
unifi_tracker = unifi.UnifiTracker()
processes = 0
unifi_tracker.Processes = processes + 1
assert(processes + 1 == unifi_tracker.Processes)

if __name__ == "__main__":
unittest.main()

0 comments on commit 2a79de8

Please sign in to comment.