diff --git a/cmd/daemon/main.go b/cmd/daemon/main.go index 275dcc8d..a2ea2d38 100644 --- a/cmd/daemon/main.go +++ b/cmd/daemon/main.go @@ -467,13 +467,11 @@ func main() { ) filesharePortController := meshnet.NewPortAccessController( - cfgMgr, netw, - meshRegistry, meshnet.NewProcessChecker(), ) fileshareProcMonitor := meshnet.NewProcMonitor( - &filesharePortController, + filesharePortController, netlinkMonitorSetupFn, ) diff --git a/daemon/jobs_test.go b/daemon/jobs_test.go index ee9692eb..f53b1fcd 100644 --- a/daemon/jobs_test.go +++ b/daemon/jobs_test.go @@ -129,6 +129,10 @@ func (n *meshNetworker) AllowFileshare(address meshnet.UniqueAddress) error { return nil } +func (n *meshNetworker) PermitFileshare() error { + return nil +} + func (n *meshNetworker) AllowIncoming(address meshnet.UniqueAddress, lanAllowed bool) error { n.allowedIncoming = append(n.allowedIncoming, address) return nil @@ -144,6 +148,10 @@ func (n *meshNetworker) BlockFileshare(address meshnet.UniqueAddress) error { return nil } +func (n *meshNetworker) ForbidFileshare() error { + return nil +} + func (*meshNetworker) ResetRouting(mesh.MachinePeer, mesh.MachinePeers) error { return nil } func (*meshNetworker) BlockRouting(meshnet.UniqueAddress) error { return nil } func (*meshNetworker) Refresh(mesh.MachineMap) error { return nil } diff --git a/meshnet/monitor_event_handler.go b/meshnet/monitor_event_handler.go index 0e706bfe..e7952db0 100644 --- a/meshnet/monitor_event_handler.go +++ b/meshnet/monitor_event_handler.go @@ -2,34 +2,24 @@ package meshnet import ( "errors" - "fmt" "log" "sync" - "github.com/NordSecurity/nordvpn-linux/config" - "github.com/NordSecurity/nordvpn-linux/core/mesh" "github.com/NordSecurity/nordvpn-linux/internal" ) var ErrIncorrectCmdlineContent = errors.New("invalid content of cmdline file of /proc") -// FilesharePortAccessController blocks or allows fileshare port when -// fileshare process stopped or was restarted accordingly. +// FilesharePortAccessController forbids or permits fileshare port +// use when fileshare process stopped or was restarted accordingly. type FilesharePortAccessController struct { - cm config.Manager - netw Networker - reg mesh.Registry + netw FileshareNetworker filesharePID PID processChecker ProcessChecker mu sync.Mutex } -func NewPortAccessController( - cm config.Manager, - netw Networker, - reg mesh.Registry, - pc ProcessChecker, -) FilesharePortAccessController { +func NewPortAccessController(netw FileshareNetworker, pc ProcessChecker) *FilesharePortAccessController { filesharePID := PID(0) // NOTE:if the fileshare is already running, set the initial PID. // This can happen only when the daemon was restarted, but nordfileshare @@ -38,10 +28,8 @@ func NewPortAccessController( if PID != nil { filesharePID = *PID } - return FilesharePortAccessController{ - cm: cm, + return &FilesharePortAccessController{ netw: netw, - reg: reg, filesharePID: filesharePID, processChecker: pc, } @@ -70,76 +58,19 @@ func (eventHandler *FilesharePortAccessController) OnProcessStarted(ev ProcEvent log.Println(internal.InfoPrefix, "updating fileshare process pid to:", ev.PID) eventHandler.filesharePID = ev.PID - go eventHandler.allowFileshare() -} - -func (eventHandler *FilesharePortAccessController) allowFileshare() error { - log.Println(internal.InfoPrefix, "allowing fileshare port") - - eventHandler.mu.Lock() - defer eventHandler.mu.Unlock() - - peers, err := eventHandler.listPeers() - if err != nil { - return err - } - - for _, peer := range peers { - peerUniqAddr := UniqueAddress{UID: peer.PublicKey, Address: peer.Address} - if err := eventHandler.netw.AllowFileshare(peerUniqAddr); err != nil { - return err - } - } - - return nil -} - -func (eventHandler *FilesharePortAccessController) listPeers() (mesh.MachinePeers, error) { - var cfg config.Config - if err := eventHandler.cm.Load(&cfg); err != nil { - return nil, fmt.Errorf("reading configuration when listing peers: %w", err) - } - - if cfg.MeshDevice == nil { - return nil, fmt.Errorf("meshnet is not configured") - } - - token := cfg.TokensData[cfg.AutoConnectData.ID].Token - peers, err := eventHandler.reg.List(token, cfg.MeshDevice.ID) - if err != nil { - return nil, fmt.Errorf("listing peers: %w", err) - } - return peers, nil + go eventHandler.netw.PermitFileshare() } func (eventHandler *FilesharePortAccessController) OnProcessStopped(ev ProcEvent) { + if eventHandler.filesharePID == 0 { + return + } if eventHandler.filesharePID != ev.PID { return } log.Println(internal.InfoPrefix, "resetting fileshare pid") eventHandler.filesharePID = 0 - go eventHandler.blockFileshare() -} - -func (eventHandler *FilesharePortAccessController) blockFileshare() error { - log.Println(internal.InfoPrefix, "blocking fileshare port") - - eventHandler.mu.Lock() - defer eventHandler.mu.Unlock() - - peers, err := eventHandler.listPeers() - if err != nil { - return err - } - - for _, peer := range peers { - peerUniqAddr := UniqueAddress{UID: peer.PublicKey, Address: peer.Address} - if err := eventHandler.netw.BlockFileshare(peerUniqAddr); err != nil { - return err - } - } - - return nil + go eventHandler.netw.ForbidFileshare() } // ProcessChecker represents process-related utilities @@ -148,3 +79,9 @@ type ProcessChecker interface { GiveProcessPID(string) *PID CurrentPID() PID } + +// FileshareNetworker represents ability of a networker to permit or forbid fileshare +type FileshareNetworker interface { + PermitFileshare() error + ForbidFileshare() error +} diff --git a/meshnet/monitor_event_handler_test.go b/meshnet/monitor_event_handler_test.go index bcf094ef..37866c57 100644 --- a/meshnet/monitor_event_handler_test.go +++ b/meshnet/monitor_event_handler_test.go @@ -1,14 +1,15 @@ package meshnet import ( + "sync" "testing" + "time" "github.com/NordSecurity/nordvpn-linux/test/category" - "github.com/NordSecurity/nordvpn-linux/test/mock" "github.com/stretchr/testify/assert" ) -func TestEventHandler_OnProcessStarted(t *testing.T) { +func TestOnProcessStarted_ManagesPID(t *testing.T) { category.Set(t, category.Unit) tests := []struct { @@ -23,14 +24,14 @@ func TestEventHandler_OnProcessStarted(t *testing.T) { savedPID: PID(0), procEvent: ProcEvent{1337}, finalPID: PID(1337), - pc: procCheckerStub{isFileshare: true, currentPID: 1336}, // currentPID lower than the PID from event + pc: procCheckerStub{isFileshare: true, daemonPID: 1336}, // currentPID lower than the PID from event }, { name: "PID is not updated if the event's PID is older than current process PID", savedPID: PID(0), procEvent: ProcEvent{1337}, finalPID: PID(0), - pc: procCheckerStub{isFileshare: true, currentPID: 1338}, // currentPID higher than the PID from event + pc: procCheckerStub{isFileshare: true, daemonPID: 1338}, // currentPID higher than the PID from event }, { name: "PID is not updated if process is NOT fileshare", @@ -51,8 +52,7 @@ func TestEventHandler_OnProcessStarted(t *testing.T) { pac := FilesharePortAccessController{ filesharePID: tt.savedPID, processChecker: tt.pc, - cm: mock.NewMockConfigManager(), - reg: &mock.RegistryMock{}, + netw: fileshareNetworkerDummy{}, } assert.Equal(t, tt.savedPID, pac.filesharePID) @@ -63,7 +63,74 @@ func TestEventHandler_OnProcessStarted(t *testing.T) { } } -func TestEventHandler_OnProcessStopped(t *testing.T) { +func TestOnProcessStarted_PermitsFileshareWhenProcessStarted(t *testing.T) { + category.Set(t, category.Unit) + fileshareNetworker := newNetworkerSpy() + pac := FilesharePortAccessController{ + filesharePID: PID(0), + processChecker: procCheckerStub{ + isFileshare: true, // detects every new process as fileshare process + daemonPID: 1336, + }, + netw: fileshareNetworker, + } + // new fileshare process event appeared and it's younger than daemonPID + newEvent := ProcEvent{1337} + assert.Equal(t, PID(0), pac.filesharePID) + assert.False(t, fileshareNetworker.permitCalled) + + pac.OnProcessStarted(newEvent) + fileshareNetworker.waitForPermitCall(t) + + assert.True(t, fileshareNetworker.permitCalled) +} + +func TestOnProcessStarted_DoesNotPermitFileshare(t *testing.T) { + category.Set(t, category.Unit) + + tests := []struct { + name string + savedPID PID + procEvent ProcEvent + pc ProcessChecker + }{ + { + name: "when the event's PID is older than current process PID", + savedPID: PID(0), + procEvent: ProcEvent{1337}, + pc: procCheckerStub{isFileshare: true, daemonPID: 1338}, // daemon PID higher than the PID from event + }, + { + name: "when process is NOT fileshare", + savedPID: PID(0), + pc: procCheckerStub{isFileshare: false}, + }, + { + name: "when fileshare was already permitted", + savedPID: PID(1337), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fileshareNetworker := newNetworkerSpy() + pac := FilesharePortAccessController{ + filesharePID: tt.savedPID, + processChecker: tt.pc, + netw: fileshareNetworker, + } + assert.Equal(t, tt.savedPID, pac.filesharePID) + assert.False(t, fileshareNetworker.permitCalled) + + pac.OnProcessStarted(tt.procEvent) + fileshareNetworker.ensurePermitNotCalled(t, 100*time.Microsecond) + + assert.False(t, fileshareNetworker.permitCalled) + }) + } +} + +func TestOnProcessStopped_ManagesPID(t *testing.T) { category.Set(t, category.Unit) tests := []struct { @@ -92,8 +159,7 @@ func TestEventHandler_OnProcessStopped(t *testing.T) { pac := FilesharePortAccessController{ filesharePID: tt.savedPID, processChecker: procCheckerStub{isFileshare: notImportant}, - cm: mock.NewMockConfigManager(), - reg: &mock.RegistryMock{}, + netw: fileshareNetworkerDummy{}, } assert.Equal(t, tt.savedPID, pac.filesharePID) @@ -104,9 +170,68 @@ func TestEventHandler_OnProcessStopped(t *testing.T) { } } +func TestOnProcessStopped_ForbidsFileshareWhenProcessStopped(t *testing.T) { + category.Set(t, category.Unit) + fileshareNetworker := newNetworkerSpy() + pac := FilesharePortAccessController{ + filesharePID: PID(1337), + // detects every new process as fileshare process + processChecker: procCheckerStub{isFileshare: true}, + netw: fileshareNetworker, + } + // new fileshare process event appeared with PID the same as recorded fileshare PID + newEvent := ProcEvent{1337} + assert.Equal(t, PID(1337), pac.filesharePID) + assert.False(t, fileshareNetworker.forbidCalled) + + pac.OnProcessStopped(newEvent) + fileshareNetworker.waitForForbidCall(t) + + assert.True(t, fileshareNetworker.forbidCalled) +} + +func TestOnProcessStarted_DoesNotForbidFileshare(t *testing.T) { + category.Set(t, category.Unit) + + tests := []struct { + name string + savedPID PID + procEvent ProcEvent + }{ + { + name: "when the event's PID does not match saved fileshare PID", + savedPID: PID(1337), + procEvent: ProcEvent{666}, + }, + { + name: "when fileshare was already forbidden", + savedPID: PID(0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fileshareNetworker := newNetworkerSpy() + pac := FilesharePortAccessController{ + filesharePID: tt.savedPID, + processChecker: procCheckerStub{isFileshare: true}, + netw: fileshareNetworker, + } + assert.Equal(t, tt.savedPID, pac.filesharePID) + assert.False(t, fileshareNetworker.forbidCalled) + + pac.OnProcessStopped(tt.procEvent) + fileshareNetworker.ensureForbidNotCalled(t, 100*time.Microsecond) + + assert.False(t, fileshareNetworker.permitCalled) + }) + } +} + +// procChecker type procCheckerStub struct { isFileshare bool - currentPID PID + daemonPID PID } func (pc procCheckerStub) IsFileshareProcess(PID) bool { @@ -118,5 +243,89 @@ func (pu procCheckerStub) GiveProcessPID(string) *PID { } func (pc procCheckerStub) CurrentPID() PID { - return pc.currentPID + return pc.daemonPID +} + +// fileshareNetworker spy +type fileshareNetworkerSpy struct { + permitCh chan struct{} + permitCalled bool + forbidCh chan struct{} + forbidCalled bool +} + +func newNetworkerSpy() *fileshareNetworkerSpy { + var wg sync.WaitGroup + wg.Add(1) + return &fileshareNetworkerSpy{ + permitCh: make(chan struct{}), + permitCalled: false, + forbidCh: make(chan struct{}), + forbidCalled: false, + } +} + +func (fn *fileshareNetworkerSpy) PermitFileshare() error { + fn.permitCh <- struct{}{} + return nil +} + +func (fn *fileshareNetworkerSpy) waitForPermitCall(t *testing.T) { + t.Helper() + select { + case <-fn.permitCh: + fn.permitCalled = true + return + case <-time.After(time.Second): + t.Fatal("fileshare should be permitted but was not") + } +} + +func (fn *fileshareNetworkerSpy) ensurePermitNotCalled(t *testing.T, d time.Duration) { + t.Helper() + select { + case <-fn.permitCh: + t.Fatal("fileshare should NOT be permitted but was") + return + case <-time.After(d): + // OK + } +} + +func (fn *fileshareNetworkerSpy) ForbidFileshare() error { + fn.forbidCh <- struct{}{} + return nil +} + +func (fn *fileshareNetworkerSpy) waitForForbidCall(t *testing.T) { + t.Helper() + select { + case <-fn.forbidCh: + fn.forbidCalled = true + return + case <-time.After(time.Second): + t.Fatal("fileshare should be forbidden but was not") + } +} + +func (fn *fileshareNetworkerSpy) ensureForbidNotCalled(t *testing.T, d time.Duration) { + t.Helper() + select { + case <-fn.forbidCh: + t.Fatal("fileshare should NOT be forbidden but was") + return + case <-time.After(d): + // OK + } +} + +// fileshareNetworker dummy +type fileshareNetworkerDummy struct{} + +func (fn fileshareNetworkerDummy) PermitFileshare() error { + return nil +} + +func (fn fileshareNetworkerDummy) ForbidFileshare() error { + return nil } diff --git a/meshnet/networker.go b/meshnet/networker.go index 3942cf84..57e4743e 100644 --- a/meshnet/networker.go +++ b/meshnet/networker.go @@ -28,8 +28,12 @@ type Networker interface { BlockIncoming(UniqueAddress) error // AllowFileshare creates a rule enabling fileshare port for the given address AllowFileshare(UniqueAddress) error + // PermitFileshare creates a rules enabling fileshare port for all available peers and sets fileshare as permitted + PermitFileshare() error // BlockFileshare removes a rule enabling fileshare port for the given address if it exists BlockFileshare(UniqueAddress) error + // ForbidFileshare removes a rules enabling fileshare port for all available peers and sets fileshare as forbidden + ForbidFileshare() error // ResetRouting is used when there are routing setting changes, // except when routing is denied - then BlockRouting must be used. changedPeer is the peer whose routing settings // changed, peers is the map of all the machine peers(including the changed peer). diff --git a/meshnet/process_checker_test.go b/meshnet/process_checker_test.go index f908751f..c2faac60 100644 --- a/meshnet/process_checker_test.go +++ b/meshnet/process_checker_test.go @@ -12,7 +12,6 @@ import ( "github.com/stretchr/testify/assert" ) -// XXX: Update tests (including name) func TestGiveProcessPID(t *testing.T) { category.Set(t, category.Unit) diff --git a/meshnet/server_test.go b/meshnet/server_test.go index 5218d490..244653ea 100644 --- a/meshnet/server_test.go +++ b/meshnet/server_test.go @@ -94,6 +94,10 @@ func (n *workingNetworker) AllowFileshare(address UniqueAddress) error { return nil } +func (n *workingNetworker) PermitFileshare() error { + return nil +} + func (n *workingNetworker) AllowIncoming(address UniqueAddress, lanAllowed bool) error { n.allowedIncoming = append(n.allowedIncoming, allowedIncoming{ address: address, @@ -113,6 +117,10 @@ func (n *workingNetworker) BlockFileshare(address UniqueAddress) error { return nil } +func (n *workingNetworker) ForbidFileshare() error { + return nil +} + func (n *workingNetworker) ResetRouting(changedPeer mesh.MachinePeer, peer mesh.MachinePeers) error { n.resetPeers = append(n.resetPeers, changedPeer.PublicKey) diff --git a/networker/networker.go b/networker/networker.go index fb93a89a..4da58291 100644 --- a/networker/networker.go +++ b/networker/networker.go @@ -168,7 +168,8 @@ type Combined struct { enableLocalTraffic bool // list with the existing OS interfaces when VPN was connected. // This is used at network changes to know when a new interface was inserted - interfaces mapset.Set[string] + interfaces mapset.Set[string] + isFilesharePermitted bool } // NewCombined returns a ready made version of @@ -1509,6 +1510,11 @@ func (netw *Combined) AllowFileshare(uniqueAddress meshnet.UniqueAddress) error } func (netw *Combined) allowFileshare(publicKey string, address netip.Addr) error { + if !netw.isFilesharePermitted { + log.Println(internal.WarningPrefix, "fileshare is not permitted, can't add allow rules") + return nil + } + ruleName := publicKey + "-allow-fileshare-rule-" + address.String() rules := []firewall.Rule{{ Name: ruleName, @@ -1537,6 +1543,24 @@ func (netw *Combined) allowFileshare(publicKey string, address netip.Addr) error return nil } +func (netw *Combined) PermitFileshare() error { + netw.mu.Lock() + defer netw.mu.Unlock() + netw.isFilesharePermitted = true + return netw.allowFileshareAll() +} + +func (netw *Combined) allowFileshareAll() error { + var allErrors []error + for _, peer := range netw.cfg.Peers { + if peer.DoIAllowFileshare { + err := netw.allowFileshare(peer.PublicKey, peer.Address) + allErrors = append(allErrors, err) + } + } + return errors.Join(allErrors...) +} + func (netw *Combined) undenyDNS() error { ruleName := "deny-private-dns" @@ -1608,7 +1632,11 @@ func (netw *Combined) blockIncoming(uniqueAddress meshnet.UniqueAddress) error { func (netw *Combined) BlockFileshare(uniqueAddress meshnet.UniqueAddress) error { netw.mu.Lock() defer netw.mu.Unlock() - ruleName := uniqueAddress.UID + "-allow-fileshare-rule-" + uniqueAddress.Address.String() + return netw.blockFileshare(uniqueAddress.UID, uniqueAddress.Address) +} + +func (netw *Combined) blockFileshare(publicKey string, address netip.Addr) error { + ruleName := publicKey + "-allow-fileshare-rule-" + address.String() return netw.removeRule(ruleName) } @@ -1627,6 +1655,22 @@ func (netw *Combined) removeRule(ruleName string) error { return nil } +func (netw *Combined) ForbidFileshare() error { + netw.mu.Lock() + defer netw.mu.Unlock() + netw.isFilesharePermitted = false + return netw.blockFileshareAll() +} + +func (netw *Combined) blockFileshareAll() error { + var allErrors []error + for _, peer := range netw.cfg.Peers { + err := netw.blockFileshare(peer.PublicKey, peer.Address) + allErrors = append(allErrors, err) + } + return errors.Join(allErrors...) +} + func getHostsFromConfig(peers mesh.MachinePeers) dns.Hosts { hosts := make(dns.Hosts, 0, len(peers)) for _, peer := range peers { diff --git a/networker/networker_test.go b/networker/networker_test.go index 8fffd4ef..3d2f3316 100644 --- a/networker/networker_test.go +++ b/networker/networker_test.go @@ -1872,7 +1872,16 @@ func TestCombined_Refresh(t *testing.T) { for _, rule := range fw.rules { ruleNames = append(ruleNames, rule.Name) } - assert.Equal(t, 6, len(fw.rules), "%d firewall rules were configured, expected 5, rules content: \n%s", + // fileshare not permitted, so its rules are not added + assert.Equal(t, 5, len(fw.rules), "%d firewall rules were configured, expected 5, rules content: \n%s", + len(fw.rules), + strings.Join(ruleNames, "\n")) + + netw.isFilesharePermitted = true + netw.Refresh(machineMap) + + // fileshare permitted now, so it's rules are added + assert.Equal(t, 6, len(fw.rules), "%d firewall rules were configured, expected 6, rules content: \n%s", len(fw.rules), strings.Join(ruleNames, "\n")) @@ -2426,3 +2435,23 @@ func TestResetRouting(t *testing.T) { }) } } + +func TestCombined_PermitFileshare_SetsIsFileshareAllowedToTrue(t *testing.T) { + category.Set(t, category.Unit) + netw := GetTestCombined() + assert.False(t, netw.isFilesharePermitted) + + netw.PermitFileshare() + + assert.True(t, netw.isFilesharePermitted) +} + +func TestCombined_ForbidFileshare_SetIsFileshareAllowedToFalse(t *testing.T) { + category.Set(t, category.Unit) + netw := GetTestCombined() + netw.isFilesharePermitted = true + + netw.ForbidFileshare() + + assert.False(t, netw.isFilesharePermitted) +}