forked from dadevel/wg-netns
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwg-netns.py
executable file
·231 lines (171 loc) · 7.99 KB
/
wg-netns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
#!/usr/bin/env python3
from argparse import ArgumentParser, RawDescriptionHelpFormatter
from pathlib import Path
import json
import os
import subprocess
import sys
# Module-level defaults. main() replaces each of them with the value of the
# environment variable of the same name when that variable is set.

# Base directory for per-namespace files; resolv.conf is written to
# <NETNS_CONFIG_DIR>/<namespace name>/resolv.conf.
NETNS_CONFIG_DIR = Path('/etc/netns')
# When non-zero, every executed command is echoed to stderr and exceptions
# propagate with a stack trace instead of a one-line error message.
DEBUG_LEVEL = 0
# Program used to execute profile hooks, invoked as `<SHELL> -c <command>`.
SHELL = Path('/bin/sh')
def main(args):
    """Parse the command line, apply environment overrides and dispatch.

    args: list of command-line arguments without the program name.

    Supported actions are 'up <profile>' and 'down [-f] <profile>'. The
    module globals NETNS_CONFIG_DIR, DEBUG_LEVEL and SHELL are overwritten
    with the values of the equally named environment variables when set.

    Raises RuntimeError when an environment variable cannot be converted
    (for example a non-integer DEBUG_LEVEL).
    """
    global NETNS_CONFIG_DIR
    global DEBUG_LEVEL
    global SHELL
    entrypoint = ArgumentParser(
        formatter_class=RawDescriptionHelpFormatter,
        epilog=(
            'environment variables:\n'
            f'  NETNS_CONFIG_DIR  network namespace config directory, default: {NETNS_CONFIG_DIR}\n'
            f'  DEBUG_LEVEL       print stack traces, 0 or 1, default: {DEBUG_LEVEL}\n'
            f'  SHELL             program for execution of shell hooks, default: {SHELL}\n'
        ),
    )
    subparsers = entrypoint.add_subparsers(dest='action', required=True)
    parser = subparsers.add_parser('up', help='setup namespace and associated interfaces')
    parser.add_argument('profile', type=lambda x: Path(x).expanduser(), help='path to profile')
    parser = subparsers.add_parser('down', help='teardown namespace and associated interfaces')
    parser.add_argument('-f', '--force', action='store_true', help='ignore errors')
    parser.add_argument('profile', type=lambda x: Path(x).expanduser(), help='path to profile')
    opts = entrypoint.parse_args(args)
    try:
        NETNS_CONFIG_DIR = Path(os.environ.get('NETNS_CONFIG_DIR', NETNS_CONFIG_DIR))
        DEBUG_LEVEL = int(os.environ.get('DEBUG_LEVEL', DEBUG_LEVEL))
        SHELL = Path(os.environ.get('SHELL', SHELL))
    except Exception as e:
        # The class name must be interpolated; previously the braces were
        # missing and the literal text "(e.__class__.__name__)" was printed.
        raise RuntimeError(f'failed to load environment variable: {e} ({e.__class__.__name__})') from e
    if opts.action == 'up':
        setup_action(opts.profile)
    elif opts.action == 'down':
        # --force disables error checking during teardown
        teardown_action(opts.profile, check=not opts.force)
    else:
        # argparse rejects unknown actions, so this is only a safety net
        raise RuntimeError('congratulations, you reached unreachable code')
def setup_action(path):
    """Bring up the namespace described by the profile at *path*.

    If setup fails for any reason (BaseException, so this includes
    KeyboardInterrupt), a teardown with check=False is attempted and the
    original exception is re-raised.
    """
    profile = profile_read(path)
    try:
        namespace_setup(profile)
    except BaseException:
        # roll back whatever was created so far, then propagate the failure
        namespace_teardown(profile, check=False)
        raise
def teardown_action(path, check=True):
    """Load the profile at *path* and tear down the namespace it describes.

    check is passed through unchanged to namespace_teardown().
    """
    profile = profile_read(path)
    namespace_teardown(profile, check=check)
def profile_read(path):
    """Return the parsed JSON content of the profile file at *path*."""
    with open(path) as handle:
        raw = handle.read()
    return json.loads(raw)
def namespace_setup(namespace):
    """Create and configure the namespace described by *namespace*.

    namespace: parsed profile, a dict with 'name', 'interfaces' and the
        optional keys 'pre-up', 'post-up', 'no-netns-post-up',
        'no-resolvconf-write' and 'dns-server'.

    Order of operations: 'pre-up' hook, namespace creation, resolv.conf
    (unless disabled), interface setup, 'post-up' hook, and finally the
    'no-netns-post-up' command.

    Raises RuntimeError when one of the executed commands fails.
    """
    if namespace.get('pre-up'):
        # NOTE(review): this hook goes through `ip netns exec` before
        # namespace_create() has run — confirm the namespace is expected to
        # exist at this point.
        ip_netns_shell(namespace['pre-up'], netns=namespace)
    namespace_create(namespace)
    # The flag may be a JSON boolean (true) or a string ("true"); str()
    # normalises both, whereas str.lower() raised TypeError on a boolean.
    if str(namespace.get('no-resolvconf-write', '')).lower() != 'true':
        namespace_resolvconf_write(namespace)
    for interface in namespace['interfaces']:
        interface_setup(interface, namespace)
    if namespace.get('post-up'):
        ip_netns_shell(namespace['post-up'], netns=namespace)
    if namespace.get('no-netns-post-up'):
        # NOTE(review): unlike the other hooks this value is executed
        # directly as a single program name, not through SHELL -c — confirm
        # that is intended.
        run(namespace['no-netns-post-up'])
def namespace_create(namespace):
    """Add the network namespace and bring up its loopback device."""
    netns_name = namespace['name']
    ip('netns', 'add', netns_name)
    ip('-n', netns_name, 'link', 'set', 'dev', 'lo', 'up')
def namespace_resolvconf_write(namespace):
    """Write the profile's DNS servers to the namespace's resolv.conf.

    One 'nameserver <server>' line is emitted per entry of 'dns-server'
    and stored in NETNS_CONFIG_DIR/<name>/resolv.conf; the directory is
    created when missing. Nothing is written when no servers are configured.
    """
    servers = namespace.get('dns-server')
    if not servers:
        return
    lines = [f'nameserver {server}' for server in servers]
    content = '\n'.join(lines)
    if not content:
        return
    target_dir = NETNS_CONFIG_DIR.joinpath(namespace['name'])
    target_dir.mkdir(parents=True, exist_ok=True)
    target_dir.joinpath('resolv.conf').write_text(content)
def namespace_teardown(namespace, check=True):
    """Tear down the namespace described by *namespace*.

    Runs the optional 'pre-down' hook, removes every profile interface that
    is still present in the namespace, deletes the namespace and its
    resolv.conf, then runs the optional 'post-down' hook.

    namespace: parsed profile (dict with 'name', 'interfaces' and optional
        hook entries).
    check: forwarded to every executed command; when False, non-zero exit
        statuses are ignored so the cleanup proceeds as far as possible.

    Raises RuntimeError if a command fails while check is True.
    """
    if namespace.get('pre-down'):
        ip_netns_shell(namespace['pre-down'], netns=namespace, check=check)
    for interface in namespace['interfaces']:
        try:
            present = check_for_interface_in_namespace(interface, namespace, check=check)
        except ValueError:
            # With check=False a failed `ip` call yields no JSON to parse;
            # treat the interface as absent instead of aborting the cleanup.
            if check:
                raise
            present = False
        if present:
            # remove only existing interfaces
            interface_teardown(interface, namespace, check=check)
    namespace_delete(namespace, check=check)
    namespace_resolvconf_delete(namespace)
    if namespace.get('post-down'):
        # NOTE(review): this hook goes through `ip netns exec` after the
        # namespace has been deleted — confirm that is intended.
        ip_netns_shell(namespace['post-down'], netns=namespace, check=check)
def namespace_delete(namespace, check=True):
    """Delete the network namespace with `ip netns delete`.

    check is passed through to the command runner.
    """
    netns_name = namespace['name']
    ip('netns', 'delete', netns_name, check=check)
def namespace_resolvconf_delete(namespace):
    """Remove the namespace's resolv.conf and clean up empty directories.

    Deletes NETNS_CONFIG_DIR/<name>/resolv.conf if it exists, then removes
    the per-namespace directory and, as before, NETNS_CONFIG_DIR itself, but
    only when they are empty. Directories that still hold other files are
    left untouched.
    """
    config_dir = Path(NETNS_CONFIG_DIR)
    namespace_dir = config_dir / namespace['name']
    try:
        (namespace_dir / 'resolv.conf').unlink()
    except (FileNotFoundError, NotADirectoryError):
        # nothing was written for this namespace
        pass
    # The per-namespace directory has to go first: the parent can only be
    # removed once it is empty. Previously only the parent was attempted,
    # which left the namespace directory behind.
    for directory in (namespace_dir, config_dir):
        try:
            directory.rmdir()
        except OSError:
            # missing, not empty or not permitted: deliberately best effort
            pass
def interface_setup(interface, namespace):
    """Create and fully configure one WireGuard interface of the profile.

    Steps, in order: create the link, apply the WireGuard device settings,
    configure every peer, assign addresses, bring the link up, add routes.
    """
    interface_create(interface, namespace)
    interface_configure_wireguard(interface, namespace)
    for entry in interface['peers']:
        peer_setup(entry, interface, namespace)
    # the remaining steps all share the same (interface, namespace) signature
    for step in (interface_assign_addresses, interface_bring_up, interface_create_routes):
        step(interface, namespace)
def interface_create(interface, namespace):
    """Add a WireGuard link and move it into the target namespace.

    The link is added without `-n`, i.e. outside the target namespace, and
    is then reassigned to it with `ip link set ... netns`.
    """
    link_name = interface['name']
    ip('link', 'add', link_name, 'type', 'wireguard')
    ip('link', 'set', link_name, 'netns', namespace['name'])
def interface_configure_wireguard(interface, namespace):
    """Apply listen port, fwmark and private key to the WireGuard device.

    'listen-port' and 'fwmark' default to 0 when absent from the profile.
    The private key is handed to `wg` on standard input via /dev/stdin.
    """
    device = interface['name']
    listen_port = interface.get('listen-port', 0)
    wg('set', device, 'listen-port', listen_port, netns=namespace)
    fwmark = interface.get('fwmark', 0)
    wg('set', device, 'fwmark', fwmark, netns=namespace)
    private_key = interface['private-key']
    wg('set', device, 'private-key', '/dev/stdin', stdin=private_key, netns=namespace)
def interface_assign_addresses(interface, namespace):
    """Add every address listed under 'address' to the interface.

    An address containing ':' is added with `-6`, any other with `-4`.
    """
    for addr in interface['address']:
        netns_name = namespace['name']
        family_flag = '-4'
        if ':' in addr:
            family_flag = '-6'
        ip('-n', netns_name, family_flag, 'address', 'add', addr, 'dev', interface['name'])
def interface_bring_up(interface, namespace):
    """Set the interface MTU (default 1420) and bring the link up."""
    command = [
        '-n', namespace['name'],
        'link', 'set', 'dev', interface['name'],
        'mtu', interface.get('mtu', 1420),
        'up',
    ]
    ip(*command)
def interface_create_routes(interface, namespace):
    """Add a route through the interface for every allowed IP of every peer.

    Peers without 'allowed-ips' contribute no routes. A network containing
    ':' is added with `-6`, any other with `-4`.
    """
    networks = (
        network
        for peer in interface['peers']
        for network in peer.get('allowed-ips', ())
    )
    for network in networks:
        netns_name = namespace['name']
        family_flag = '-4'
        if ':' in network:
            family_flag = '-6'
        ip('-n', netns_name, family_flag, 'route', 'add', network, 'dev', interface['name'])
def interface_teardown(interface, namespace, check=True):
    """Set the interface down and delete it from the namespace.

    check is passed through to both `ip` invocations.
    """
    netns_name = namespace['name']
    link_name = interface['name']
    ip('-n', netns_name, 'link', 'set', link_name, 'down', check=check)
    ip('-n', netns_name, 'link', 'delete', link_name, check=check)
def check_for_interface_in_namespace(interface, namespace, check=True):
    """
    checks the presence of given interface and returns True if it is present

    interface: the interface's profile entry (a dict with a 'name' key) or
        the bare interface name as a string.
    namespace: parsed profile; its 'name' selects the namespace to inspect.
    check: passed to the `ip` call; when False a failing call is tolerated
        and reported as "not present".
    """
    # Callers pass the whole profile entry; comparing that dict with the
    # 'ifname' strings reported by `ip` could never match.
    wanted = interface['name'] if isinstance(interface, dict) else interface
    listing = ip('-br', '-j', '-n', namespace['name'], 'link', 'list', check=check, capture=True)
    if not listing or not listing.strip():
        # no output (e.g. the command failed while check=False): nothing to parse
        return False
    links = json.loads(listing)
    return any(link.get('ifname') == wanted for link in links)
def peer_setup(peer, interface, namespace):
    """Configure one peer on the WireGuard interface with `wg set`.

    The preshared key, when present, is fed on standard input and read by
    `wg` from /dev/stdin; otherwise /dev/null is given as the key file.
    'persistent-keepalive' defaults to 0. 'endpoint' and 'allowed-ips' are
    only passed when set in the profile.
    """
    preshared_key = peer.get('preshared-key')
    key_source = '/dev/stdin' if preshared_key else '/dev/null'
    options = [
        'peer', peer['public-key'],
        'preshared-key', key_source,
        'persistent-keepalive', peer.get('persistent-keepalive', 0),
    ]
    endpoint = peer.get('endpoint')
    if endpoint:
        options += ['endpoint', endpoint]
    allowed_ips = peer.get('allowed-ips')
    if allowed_ips:
        options += ['allowed-ips', ','.join(allowed_ips)]
    wg('set', interface['name'], *options, stdin=preshared_key, netns=namespace)
def wg(*args, **kwargs):
    """Run `wg` with *args* through ip_netns_exec(); kwargs are forwarded."""
    command = ('wg',) + args
    return ip_netns_exec(*command, **kwargs)
def ip_netns_shell(*args, **kwargs):
    """Run `SHELL -c <args...>` through ip_netns_exec(); kwargs are forwarded."""
    command = (SHELL, '-c') + args
    return ip_netns_exec(*command, **kwargs)
def ip_netns_exec(*args, netns=None, **kwargs):
    """Run a command inside a namespace via `ip netns exec <name> <args...>`.

    netns: parsed profile whose 'name' selects the namespace. It is
        subscripted unconditionally, so callers must always supply it.
    Remaining keyword arguments are forwarded to ip().
    """
    netns_name = netns['name']
    return ip('netns', 'exec', netns_name, *args, **kwargs)
def ip(*args, **kwargs):
    """Run the `ip` command with *args*; kwargs are forwarded to run()."""
    command = ('ip',) + args
    return run(*command, **kwargs)
def run(*args, stdin=None, check=True, capture=False):
    """Execute an external command and return its standard output.

    args: program and arguments; each item is converted with str(), None
        becomes an empty string.
    stdin: optional text fed to the process on standard input.
    check: raise RuntimeError when the process exits with a non-zero status.
    capture: capture stdout/stderr; without it the return value is None.

    When DEBUG_LEVEL is non-zero the command line is echoed to stderr.
    """
    argv = ['' if item is None else str(item) for item in args]
    command_line = ' '.join(argv)
    if DEBUG_LEVEL:
        print('>', command_line, file=sys.stderr)
    process = subprocess.run(argv, input=stdin, text=True, capture_output=capture)
    failed = process.returncode != 0
    if check and failed:
        # prefer the captured error output, fall back to the exit status
        if process.stderr:
            error = process.stderr.strip()
        else:
            error = f'exit code {process.returncode}'
        raise RuntimeError(f'subprocess failed: {command_line}: {error}')
    return process.stdout
if __name__ == '__main__':
    # Script entry point: exit 0 on success, 2 on any error. With
    # DEBUG_LEVEL set the exception propagates so the stack trace is shown.
    try:
        main(sys.argv[1:])
    except Exception as e:
        if DEBUG_LEVEL:
            raise
        print(f'error: {e} ({e.__class__.__name__})', file=sys.stderr)
        sys.exit(2)
    else:
        sys.exit(0)