File tree Expand file tree Collapse file tree 4 files changed +40
-35
lines changed Expand file tree Collapse file tree 4 files changed +40
-35
lines changed Original file line number Diff line number Diff line change 1
1
package settings
2
2
3
3
import (
4
- "errors"
5
4
"fmt"
6
5
"time"
7
6
@@ -30,21 +29,10 @@ func (d *DoH) setDefaults() {
30
29
d .Self .setDefaults ()
31
30
}
32
31
33
- var (
34
- ErrDoHProviderNotValid = errors .New ("DoH provider is not valid" )
35
- )
36
-
37
32
func (d * DoH ) validate () (err error ) {
38
- allProviders := provider .All ()
39
- allProvidersSet := make (map [string ]struct {}, len (allProviders ))
40
- for _ , provider := range allProviders {
41
- allProvidersSet [provider .Name ] = struct {}{}
42
- }
43
-
44
- for _ , provider := range d .DoHProviders {
45
- if _ , ok := allProvidersSet [provider ]; ! ok {
46
- return fmt .Errorf ("%w: %s" , ErrDoHProviderNotValid , provider )
47
- }
33
+ err = checkProviderNames (d .DoHProviders )
34
+ if err != nil {
35
+ return fmt .Errorf ("DoH provider: %w" , err )
48
36
}
49
37
50
38
const minTimeout = time .Millisecond
Original file line number Diff line number Diff line change @@ -43,24 +43,18 @@ func (d *DoT) setDefaults() {
43
43
}
44
44
45
45
var (
46
- ErrTimeoutTooSmall = errors .New ("timeout is too small" )
47
- ErrDoTProviderNotValid = errors .New ("DoT provider is not valid" )
48
- ErrDNSProviderNotValid = errors .New ("plaintext DNS provider is not valid" )
46
+ ErrTimeoutTooSmall = errors .New ("timeout is too small" )
49
47
)
50
48
51
49
func (d * DoT ) validate () (err error ) {
52
- allProvidersSet := allProvidersStringSet ()
53
-
54
- for _ , provider := range d .DoTProviders {
55
- if _ , ok := allProvidersSet [provider ]; ! ok {
56
- return fmt .Errorf ("%w: %s" , ErrDoTProviderNotValid , provider )
57
- }
50
+ err = checkProviderNames (d .DoTProviders )
51
+ if err != nil {
52
+ return fmt .Errorf ("DoT provider: %w" , err )
58
53
}
59
54
60
- for _ , provider := range d .DNSProviders {
61
- if _ , ok := allProvidersSet [provider ]; ! ok {
62
- return fmt .Errorf ("%w: %s" , ErrDNSProviderNotValid , provider )
63
- }
55
+ err = checkProviderNames (d .DNSProviders )
56
+ if err != nil {
57
+ return fmt .Errorf ("fallback DNS plaintext provider: %w" , err )
64
58
}
65
59
66
60
const minTimeout = time .Millisecond
Original file line number Diff line number Diff line change 1
1
package settings
2
2
3
3
func andStrings (strings []string ) (result string ) {
4
+ return joinStrings (strings , "and" )
5
+ }
6
+
7
+ func orStrings (strings []string ) (result string ) {
8
+ return joinStrings (strings , "or" )
9
+ }
10
+
11
+ func joinStrings (strings []string , lastJoin string ) (result string ) {
4
12
if len (strings ) == 0 {
5
13
return ""
6
14
}
@@ -10,7 +18,7 @@ func andStrings(strings []string) (result string) {
10
18
if i < len (strings )- 1 {
11
19
result += strings [i ] + ", "
12
20
} else {
13
- result += " and " + strings [i ]
21
+ result += " " + lastJoin + " " + strings [i ]
14
22
}
15
23
}
16
24
Original file line number Diff line number Diff line change @@ -40,11 +40,26 @@ func checkListeningAddress(address string) (err error) {
40
40
return err
41
41
}
42
42
43
- func allProvidersStringSet ( ) (set map [ string ] struct {} ) {
44
- providers := provider .All ()
45
- set = make (map [ string ] struct {} , len (providers ))
46
- for _ , provider := range providers {
47
- set [ provider . Name ] = struct {}{}
43
+ func checkProviderNames ( providerNames [] string ) (err error ) {
44
+ allProviders := provider .All ()
45
+ allProviderNames : = make ([] string , len (allProviders ))
46
+ for i , provider := range allProviders {
47
+ allProviderNames [ i ] = provider . Name
48
48
}
49
- return set
49
+
50
+ for _ , providerName := range providerNames {
51
+ valid := false
52
+ for _ , acceptedProviderName := range allProviderNames {
53
+ if strings .EqualFold (providerName , acceptedProviderName ) {
54
+ valid = true
55
+ break
56
+ }
57
+ }
58
+ if ! valid {
59
+ return fmt .Errorf ("%w: %q must be one of: %s" ,
60
+ ErrValueNotOneOf , providerName , orStrings (allProviderNames ))
61
+ }
62
+ }
63
+
64
+ return nil
50
65
}
You can’t perform that action at this time.
0 commit comments