diff --git a/opcua/server/internal_server.py b/opcua/server/internal_server.py index 875b31ca2..bc158f965 100644 --- a/opcua/server/internal_server.py +++ b/opcua/server/internal_server.py @@ -4,11 +4,14 @@ """ from datetime import datetime, timedelta -from copy import copy +from copy import deepcopy import os import logging from threading import Lock from enum import Enum +from socket import INADDR_ANY # IPv4 '0.0.0.0' +IN6ADDR_ANY = '::' +from ipaddress import ip_address try: from urllib.parse import urlparse except ImportError: @@ -168,19 +171,47 @@ def add_endpoint(self, endpoint): def get_endpoints(self, params=None, sockname=None): self.logger.info("get endpoint") - if sockname: - # return to client the ip address it has access to - edps = [] - for edp in self.endpoints: - edp1 = copy(edp) - url = urlparse(edp1.EndpointUrl) - url = url._replace(netloc=sockname[0] + ":" + str(sockname[1])) - edp1.EndpointUrl = url.geturl() - edps.append(edp1) - return edps - return self.endpoints[:] + # return to client the endpoints it has access to + netloc = self._get_netloc(params, sockname) + edps = deepcopy(self.endpoints) + for edp in edps: + edp.EndpointUrl = InternalServer._replace_inaddr_any(edp.EndpointUrl, netloc) + return edps + + @staticmethod + def _get_netloc(params=None, sockname=None): + # find the ip:port as seen by our client. + netloc = None + if params and params.EndpointUrl: + # use ip:port as provided within client request params. + netloc = urlparse(params.EndpointUrl).netloc + if not netloc and sockname: + # use ip:port extracted from our local interface. + netloc = sockname[0] + ":" + str(sockname[1]) + return netloc + + @staticmethod + def _replace_inaddr_any(urlStr, netloc): + # If urlStr is '0.0.0.0:port' or '[::]:port', use netloc ip:port. + parseResult = urlparse(urlStr) + try: + hostip = ip_address(parseResult.hostname) + except ValueError: + hostip = None + if not netloc: + pass + elif hostip in (ip_address(INADDR_ANY), ip_address(IN6ADDR_ANY)): + urlStr = parseResult._replace(netloc=netloc).geturl() + return urlStr def find_servers(self, params): + servers = deepcopy(self._filter_servers(params)) + netloc = self._get_netloc(params) + for srv in servers: + srv.DiscoveryUrls = [self._replace_inaddr_any(dUrl, netloc) for dUrl in srv.DiscoveryUrls] + return servers + + def _filter_servers(self, params): if not params.ServerUris: return [desc.Server for desc in self._known_servers.values()] servers = [] @@ -311,7 +342,7 @@ def create_session(self, params, sockname=None): result.MaxRequestMessageSize = 65536 self.nonce = utils.create_nonce(32) result.ServerNonce = self.nonce - result.ServerEndpoints = self.get_endpoints(sockname=sockname) + result.ServerEndpoints = self.get_endpoints(params=params, sockname=sockname) return result