Skip to content

Commit

Permalink
Address envelope
Browse files Browse the repository at this point in the history
  • Loading branch information
raulikak committed May 29, 2024
1 parent 101665e commit 603a3c4
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 45 deletions.
51 changes: 31 additions & 20 deletions tcsfw/address.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def get_host(self) -> Optional['AnyAddress']:
"""Get host or self"""
return self

def get_origin(self) -> 'AnyAddress':
"""Get the origin, loosely after URL naming, which is host and possible port"""
def open_envelope(self) -> 'AnyAddress':
"""Open address envelope, if any. If none, return this address"""
return self

def get_protocol_port(self) -> Optional[Tuple[Protocol, int]]:
Expand Down Expand Up @@ -203,10 +203,9 @@ def get_tag(cls, addresses: Iterable[AnyAddress]) -> Optional[EntityTag]:
@classmethod
def parse_address(cls, address: str) -> AnyAddress:
"""Parse any address type from string, type given as 'type|address'"""
ad, _, net = address.partition("@")
if net != "":
# NOTE: Network object is not properly restored
return NetworkAddress(Network(net), cls.parse_address(ad))
ad, _, con = address.partition("(")
if con and con.endswith(")"):
return AddressEnvelope(cls.parse_address(ad), cls.parse_address(con[:-1]))
v, _, t = address.rpartition("|")
if v == "":
t, v = "ip", t # default is IP
Expand All @@ -223,6 +222,9 @@ def parse_address(cls, address: str) -> AnyAddress:
@classmethod
def parse_endpoint(cls, value: str) -> AnyAddress:
"""Parse address or endpoint"""
ad, _, con = value.partition("(")
if con and con.endswith(")"):
return AddressEnvelope(cls.parse_address(ad), cls.parse_endpoint(con[:-1]))
a, _, p = value.partition("/")
addr = cls.parse_address(a)
if p == "":
Expand Down Expand Up @@ -486,10 +488,10 @@ def __repr__(self):

class Network:
"""Network"""
def __init__(self, name: str, ip_networks: List[IPv4Network | IPv6Network] = None) -> None:
def __init__(self, name: str, ip_network: Optional[IPv4Network | IPv6Network] = None) -> None:
self.name = name
# NOTE: Equality etc. is only evaluated by name
self.ip_networks = [] if ip_networks is None else ip_networks
self.ip_network = ip_network

def is_local(self, address: 'AnyAddress') -> bool:
"""Is local address for this network?"""
Expand All @@ -499,10 +501,7 @@ def is_local(self, address: 'AnyAddress') -> bool:
if h.is_multicast() or h.is_null() or not isinstance(h, IPAddress):
return True
# FIXME: Broadcast for IPv6 not implemented pylint: disable=fixme
for m in self.ip_networks:
if h.data in m:
return True
return False
return h.data in self.ip_network

def __eq__(self, other) -> bool:
return isinstance(other, Network) and self.name == other.name
Expand All @@ -523,11 +522,14 @@ class Networks:
Internet = Network("Internet") # Internet


class NetworkAddress(AnyAddress):
"""Address with explicit network"""
def __init__(self, network: Network, address: AnyAddress) -> None:
self.network = network
class AddressEnvelope(AnyAddress):
"""Address envelope carrying content address"""
def __init__(self, address: AnyAddress, content: AnyAddress):
self.address = address
self.content = content

def open_envelope(self) -> 'AnyAddress':
return self.content

def get_ip_address(self) -> Optional[IPAddress]:
return self.address.get_ip_address()
Expand All @@ -542,7 +544,7 @@ def get_protocol_port(self) -> Optional[Tuple[Protocol, int]]:
return self.address.get_protocol_port()

def change_host(self, host: 'AnyAddress') -> Self:
return NetworkAddress(self.network, self.address.change_host(host))
return AddressEnvelope(self.address.change_host(host), self.content)

def is_null(self) -> bool:
return self.address.is_null()
Expand All @@ -566,6 +568,15 @@ def priority(self) -> int:
return self.address.priority() + 1

def get_parseable_value(self) -> str:
if self.network == Networks.Default:
return self.address.get_parseable_value()
return f"{self.address.get_parseable_value()}@{self.network.name}"
return f"{self.address.get_parseable_value()}({self.content.get_parseable_value()})"

def __eq__(self, other):
if not isinstance(other, AddressEnvelope):
return False
return self.address == other.address and self.content == other.content

def __hash__(self):
return self.address.__hash__() ^ self.content.__hash__()

def __repr__(self) -> str:
return f"{self.address}({self.content})"
11 changes: 6 additions & 5 deletions tcsfw/builder_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, name="Unnamed system"):
def network(self, subnet="") -> 'NetworkBuilder':
if subnet:
return NetworkBackend(self, subnet)
return NetworkBuilder(self)
return NetworkBackend(self)

def device(self, name="") -> 'HostBackend':
name = name or self._free_host_name("Device")
Expand Down Expand Up @@ -229,7 +229,7 @@ def external_activity(self, value: ExternalActivity) -> Self:
return self

def in_networks(self, *network: NetworkBuilder) -> Self:
self.entity.network = [n.network for n in network]
self.entity.networks = [n.network for n in network]

def software(self, name: Optional[str] = None) -> 'SoftwareBackend':
if name is None:
Expand Down Expand Up @@ -455,10 +455,11 @@ class NetworkBackend(NetworkBuilder):
def __init__(self, parent: SystemBackend, name=""):
self.parent = parent
self.name = name
self.network = Network(name) if name else parent.system.network
self.network = Network(name) if name else parent.system.networks[0]

def mask(self, *mask: str) -> Self:
self.network.mask = [ipaddress.ip_network(m) for m in mask]
def mask(self, mask: str) -> Self:
self.network.mask = ipaddress.ip_network(mask)
return self

def __repr__(self) -> str:
return self.name
Expand Down
37 changes: 28 additions & 9 deletions tcsfw/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(self, name: str):
self.visual = False # show visual image?
self.children: List[Addressable] = []
self.components: List[NodeComponent] = []
self.network = Networks.Default # usually copied from parent
self.networks = [Networks.Default] # usually copied from parent
self.external_activity = ExternalActivity.BANNED

def get_children(self) -> Iterable['Entity']:
Expand Down Expand Up @@ -168,6 +168,14 @@ def is_relevant(self) -> bool:
def is_admin(self) -> bool:
return self.host_type == HostType.ADMINISTRATIVE

def get_ip_network(self, address: IPAddress) -> Network:
"""Resolve IP network for IP address"""
if len(self.networks) > 1:
for nw in self.networks:
if nw.ip_network and address in nw.ip_network:
return nw
return nw[0] # the default

def get_connections(self, relevant_only=True) -> List[Connection]:
"""Get relevant conneciions"""
cs = []
Expand Down Expand Up @@ -290,6 +298,13 @@ def is_multicast(self) -> bool:
return True
return self.addresses and any(a.is_multicast() for a in self.addresses)

def get_ip_network(self, address: IPAddress) -> Network:
"""Resolve IP network for IP address"""
if self.parent and self.parent.networks == self.networks:
# avoid repeating the same check
return self.parent.get_ip_network(address)
return super().get_ip_network(address)

def get_addresses(self, ads: Set[AnyAddress] = None) -> Set[AnyAddress]:
"""Get all addresses"""
ads = set() if ads is None else ads
Expand Down Expand Up @@ -342,7 +357,7 @@ def __init__(self, parent: 'IoTSystem', name: str, tag: Optional[EntityTag] = No
self.addresses.add(tag)
self.concept_name = "node"
self.parent = parent
self.network = parent.network
self.networks = parent.networks
self.visual = True
self.ignore_name_requests: Set[DNSName] = set()
self.connections: List[Connection] = [] # connections initiated here
Expand Down Expand Up @@ -391,7 +406,7 @@ def __init__(self, name: str, parent: Addressable):
super().__init__(name)
self.concept_name = "service"
self.parent = parent
self.network = parent.network
self.networks = parent.networks
self.protocol: Optional[Protocol] = None # known protocol
self.host_type = parent.host_type
self.con_type = ConnectionType.UNKNOWN
Expand Down Expand Up @@ -447,7 +462,7 @@ def __init__(self, name="IoT system"):
self.concept_name = "system"
self.status = Status.EXPECTED
# network mask(s)
self.network = Network("local", ip_networks=[ipaddress.ip_network("192.168.0.0/16")]) # reasonable default
self.networks = [Network("local", ip_network=ipaddress.ip_network("192.168.0.0/16"))] # reasonable default
# online resources
self.online_resources: Dict[str, str] = {}
# original entities and connections
Expand Down Expand Up @@ -477,7 +492,10 @@ def is_host_reachable(self) -> bool:

def is_external(self, address: AnyAddress) -> bool:
"""Is an external network address?"""
return not self.network.is_local(address)
for nw in self.networks:
if nw.is_local(address):
return False
return True

def learn_named_address(self, name: Union[DNSName, EntityTag], address: Optional[AnyAddress]) -> Tuple[Host, bool]:
"""Learn addresses for host, return the named host and if any changes"""
Expand Down Expand Up @@ -574,10 +592,11 @@ def get_system(self) -> 'IoTSystem':
def get_endpoint(self, address: AnyAddress) -> Addressable:
"""Get or create a new endpoint, service or host"""
h_add = address.get_host()
e_add = address.open_envelope() # scan from inside is in envelope
for e in self.children:
if h_add in e.addresses:
if isinstance(address, EndpointAddress):
e = e.get_endpoint(address) or e
if isinstance(e_add, EndpointAddress):
e = e.get_endpoint(e_add) or e
break
else:
# create new host and possibly service
Expand All @@ -590,8 +609,8 @@ def get_endpoint(self, address: AnyAddress) -> Addressable:
e.addresses.add(h_add)
e.external_activity = ExternalActivity.UNLIMITED # we know nothing about its behavior
self.children.append(e)
if isinstance(address, EndpointAddress) and e.is_host():
e = e.create_service(address)
if isinstance(e_add, EndpointAddress) and e.is_host():
e = e.create_service(e_add)
return e

def new_connection(self, source: Tuple[Addressable, AnyAddress],
Expand Down
6 changes: 3 additions & 3 deletions tcsfw/shell_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from io import BytesIO, TextIOWrapper
import re
from typing import Any, Dict, List, Set, Tuple
from tcsfw.address import Addresses, AnyAddress, EndpointAddress, HWAddresses, IPAddress
from tcsfw.address import AddressEnvelope, Addresses, AnyAddress, EndpointAddress, HWAddresses, IPAddress
from tcsfw.components import OperatingSystem
from tcsfw.event_interface import EventInterface, PropertyEvent
from tcsfw.model import IoTSystem
Expand Down Expand Up @@ -113,7 +113,7 @@ def process_endpoint(self, endpoint: AnyAddress, stream: BytesIO, interface: Eve
source: EvidenceSource):
columns: Dict[str, int] = {}
local_ads = set()
services = set()
services: Set[EndpointAddress] = set()
conns = set()

node = self.system.get_endpoint(endpoint)
Expand Down Expand Up @@ -171,7 +171,7 @@ def process_endpoint(self, endpoint: AnyAddress, stream: BytesIO, interface: Eve
if self.send_events:
evidence = Evidence(source)
for addr in sorted(services):
scan = ServiceScan(evidence, addr)
scan = ServiceScan(evidence, endpoint=AddressEnvelope(tag, addr) if tag else addr)
interface.service_scan(scan)
# NOTE: Create host scan event to report missing services

Expand Down
7 changes: 4 additions & 3 deletions tcsfw/traffic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import datetime
from typing import Any, Callable, Tuple, Set, Optional, Self, Dict

from tcsfw.address import HWAddress, IPAddress, HWAddresses, IPAddresses, Protocol, EndpointAddress, AnyAddress, \
Addresses
from tcsfw.address import EntityTag, HWAddress, IPAddress, HWAddresses, IPAddresses, Protocol, EndpointAddress, \
AnyAddress, Addresses
from tcsfw.property import PropertyKey


Expand Down Expand Up @@ -117,10 +117,11 @@ def __init__(self, evidence: Evidence, endpoint: EndpointAddress, service_name="
self.service_name = service_name

def get_data_json(self, _id_resolver: Callable[[Any], Any]) -> Dict:
return {
r = {
"endpoint": self.endpoint.get_parseable_value(),
"service": self.service_name,
}
return r

@classmethod
def decode_data_json(cls, evidence: Evidence, data: Dict, _entity_resolver: Callable[[Any], Any]) -> 'ServiceScan':
Expand Down
16 changes: 15 additions & 1 deletion tests/test_address.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from tcsfw.address import Addresses, DNSName, EndpointAddress, HWAddress, HWAddresses, IPAddress, IPAddresses, Protocol
from tcsfw.address import AddressEnvelope, Addresses, DNSName, EndpointAddress, HWAddress, HWAddresses, IPAddress, IPAddresses, Protocol


def test_hw_address():
Expand Down Expand Up @@ -84,3 +84,17 @@ def test_hw_address_generation():
ip = IPAddress.new("192.168.0.2")
hw = HWAddress.from_ip(ip)
assert hw == HWAddress('40:00:c0:a8:00:02')


def test_parse_address_envelope():
a = Addresses.parse_address("1.2.3.4(weird.com|name)")
assert isinstance(a, AddressEnvelope)
assert a.address == IPAddress.new("1.2.3.4")
assert a.content == DNSName("weird.com")


def test_parse_endpoint_address_envelope():
a = Addresses.parse_endpoint("example.com|name(1.2.3.4/udp:1234)")
assert isinstance(a, AddressEnvelope)
assert a.address == DNSName("example.com")
assert a.content == EndpointAddress.ip("1.2.3.4", Protocol.UDP, 1234)
11 changes: 7 additions & 4 deletions tests/test_shell_ss.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,22 @@ def test_shell_ss_pass():
assert len(hs) == 6
h = hs[0]
assert h.status_verdict() == (Status.EXPECTED, Verdict.PASS)
assert len(h.children) == 5
assert len(h.children) == 6
s = h.children[0]
assert s.long_name() == "Device SSH:22"
assert s.status_verdict() == (Status.EXPECTED, Verdict.PASS)
s = h.children[1]
assert s.long_name() == "Device UDP:68"
assert s.long_name() == "Device TCP:51337"
assert s.status_verdict() == (Status.UNEXPECTED, Verdict.FAIL)
s = h.children[2]
assert s.long_name() == "Device TCP:41337"
assert s.long_name() == "Device UDP:68"
assert s.status_verdict() == (Status.UNEXPECTED, Verdict.FAIL)
s = h.children[3]
assert s.long_name() == "Device UDP:1194"
assert s.long_name() == "Device TCP:41337"
assert s.status_verdict() == (Status.UNEXPECTED, Verdict.FAIL)
s = h.children[4]
assert s.long_name() == "Device UDP:1194"
assert s.status_verdict() == (Status.UNEXPECTED, Verdict.FAIL)
s = h.children[5]
assert s.long_name() == "Device UDP:123"
assert s.status_verdict() == (Status.UNEXPECTED, Verdict.FAIL)

0 comments on commit 603a3c4

Please sign in to comment.