Skip to content

Commit 3a2fbd7

Browse files
committed
fix(settings): validation for DoH and DoT providers
1 parent 94e5502 commit 3a2fbd7

File tree

4 files changed

+40
-35
lines changed

4 files changed

+40
-35
lines changed

internal/config/settings/doh.go

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package settings
22

33
import (
4-
"errors"
54
"fmt"
65
"time"
76

@@ -30,21 +29,10 @@ func (d *DoH) setDefaults() {
3029
d.Self.setDefaults()
3130
}
3231

33-
var (
34-
ErrDoHProviderNotValid = errors.New("DoH provider is not valid")
35-
)
36-
3732
func (d *DoH) validate() (err error) {
38-
allProviders := provider.All()
39-
allProvidersSet := make(map[string]struct{}, len(allProviders))
40-
for _, provider := range allProviders {
41-
allProvidersSet[provider.Name] = struct{}{}
42-
}
43-
44-
for _, provider := range d.DoHProviders {
45-
if _, ok := allProvidersSet[provider]; !ok {
46-
return fmt.Errorf("%w: %s", ErrDoHProviderNotValid, provider)
47-
}
33+
err = checkProviderNames(d.DoHProviders)
34+
if err != nil {
35+
return fmt.Errorf("DoH provider: %w", err)
4836
}
4937

5038
const minTimeout = time.Millisecond

internal/config/settings/dot.go

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -43,24 +43,18 @@ func (d *DoT) setDefaults() {
4343
}
4444

4545
var (
46-
ErrTimeoutTooSmall = errors.New("timeout is too small")
47-
ErrDoTProviderNotValid = errors.New("DoT provider is not valid")
48-
ErrDNSProviderNotValid = errors.New("plaintext DNS provider is not valid")
46+
ErrTimeoutTooSmall = errors.New("timeout is too small")
4947
)
5048

5149
func (d *DoT) validate() (err error) {
52-
allProvidersSet := allProvidersStringSet()
53-
54-
for _, provider := range d.DoTProviders {
55-
if _, ok := allProvidersSet[provider]; !ok {
56-
return fmt.Errorf("%w: %s", ErrDoTProviderNotValid, provider)
57-
}
50+
err = checkProviderNames(d.DoTProviders)
51+
if err != nil {
52+
return fmt.Errorf("DoT provider: %w", err)
5853
}
5954

60-
for _, provider := range d.DNSProviders {
61-
if _, ok := allProvidersSet[provider]; !ok {
62-
return fmt.Errorf("%w: %s", ErrDNSProviderNotValid, provider)
63-
}
55+
err = checkProviderNames(d.DNSProviders)
56+
if err != nil {
57+
return fmt.Errorf("fallback DNS plaintext provider: %w", err)
6458
}
6559

6660
const minTimeout = time.Millisecond

internal/config/settings/helpers.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
package settings
22

33
func andStrings(strings []string) (result string) {
4+
return joinStrings(strings, "and")
5+
}
6+
7+
func orStrings(strings []string) (result string) {
8+
return joinStrings(strings, "or")
9+
}
10+
11+
func joinStrings(strings []string, lastJoin string) (result string) {
412
if len(strings) == 0 {
513
return ""
614
}
@@ -10,7 +18,7 @@ func andStrings(strings []string) (result string) {
1018
if i < len(strings)-1 {
1119
result += strings[i] + ", "
1220
} else {
13-
result += " and " + strings[i]
21+
result += " " + lastJoin + " " + strings[i]
1422
}
1523
}
1624

internal/config/settings/validation.go

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,26 @@ func checkListeningAddress(address string) (err error) {
4040
return err
4141
}
4242

43-
func allProvidersStringSet() (set map[string]struct{}) {
44-
providers := provider.All()
45-
set = make(map[string]struct{}, len(providers))
46-
for _, provider := range providers {
47-
set[provider.Name] = struct{}{}
43+
func checkProviderNames(providerNames []string) (err error) {
44+
allProviders := provider.All()
45+
allProviderNames := make([]string, len(allProviders))
46+
for i, provider := range allProviders {
47+
allProviderNames[i] = provider.Name
4848
}
49-
return set
49+
50+
for _, providerName := range providerNames {
51+
valid := false
52+
for _, acceptedProviderName := range allProviderNames {
53+
if strings.EqualFold(providerName, acceptedProviderName) {
54+
valid = true
55+
break
56+
}
57+
}
58+
if !valid {
59+
return fmt.Errorf("%w: %q must be one of: %s",
60+
ErrValueNotOneOf, providerName, orStrings(allProviderNames))
61+
}
62+
}
63+
64+
return nil
5065
}

0 commit comments

Comments
 (0)