Skip to content

Commit 36425d2

Browse files
authored
Tests: Improve geosite & geoip tests (#5502)
#5488 (comment)
1 parent 39ba1f7 commit 36425d2

File tree

6 files changed

+102
-91
lines changed

6 files changed

+102
-91
lines changed

app/router/condition_geoip_test.go

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,17 @@
11
package router_test
22

33
import (
4-
"fmt"
54
"os"
65
"path/filepath"
6+
"runtime"
77
"testing"
88

99
"github.com/xtls/xray-core/app/router"
1010
"github.com/xtls/xray-core/common"
1111
"github.com/xtls/xray-core/common/net"
12-
"github.com/xtls/xray-core/common/platform"
13-
"github.com/xtls/xray-core/common/platform/filesystem"
14-
"google.golang.org/protobuf/proto"
12+
"github.com/xtls/xray-core/infra/conf"
1513
)
1614

17-
func getAssetPath(file string) (string, error) {
18-
path := platform.GetAssetLocation(file)
19-
_, err := os.Stat(path)
20-
if os.IsNotExist(err) {
21-
path := filepath.Join("..", "..", "resources", file)
22-
_, err := os.Stat(path)
23-
if os.IsNotExist(err) {
24-
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
25-
}
26-
if err != nil {
27-
return "", fmt.Errorf("can't stat %s: %v", path, err)
28-
}
29-
return path, nil
30-
}
31-
if err != nil {
32-
return "", fmt.Errorf("can't stat %s: %v", path, err)
33-
}
34-
35-
return path, nil
36-
}
37-
3815
func TestGeoIPMatcher(t *testing.T) {
3916
cidrList := []*router.CIDR{
4017
{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
@@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
182159
}
183160

184161
func TestGeoIPMatcher4CN(t *testing.T) {
185-
ips, err := loadGeoIP("CN")
162+
geo := "geoip:cn"
163+
geoip, err := loadGeoIP(geo)
186164
common.Must(err)
187165

188-
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
189-
Cidr: ips,
190-
})
166+
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
191167
common.Must(err)
192168

193169
if matcher.Match([]byte{8, 8, 8, 8}) {
@@ -196,50 +172,46 @@ func TestGeoIPMatcher4CN(t *testing.T) {
196172
}
197173

198174
func TestGeoIPMatcher6US(t *testing.T) {
199-
ips, err := loadGeoIP("US")
175+
geo := "geoip:us"
176+
geoip, err := loadGeoIP(geo)
200177
common.Must(err)
201178

202-
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
203-
Cidr: ips,
204-
})
179+
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
205180
common.Must(err)
206181

207182
if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
208183
t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not")
209184
}
210185
}
211186

212-
func loadGeoIP(country string) ([]*router.CIDR, error) {
213-
path, err := getAssetPath("geoip.dat")
214-
if err != nil {
215-
return nil, err
216-
}
217-
geoipBytes, err := filesystem.ReadFile(path)
187+
func loadGeoIP(geo string) (*router.GeoIP, error) {
188+
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
189+
190+
geoip, err := conf.ToCidrList([]string{geo})
218191
if err != nil {
219192
return nil, err
220193
}
221194

222-
var geoipList router.GeoIPList
223-
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
224-
return nil, err
195+
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
196+
geoip, err = router.GetGeoIPList(geoip)
197+
if err != nil {
198+
return nil, err
199+
}
225200
}
226201

227-
for _, geoip := range geoipList.Entry {
228-
if geoip.CountryCode == country {
229-
return geoip.Cidr, nil
230-
}
202+
if len(geoip) == 0 {
203+
panic("country not found: " + geo)
231204
}
232205

233-
panic("country not found: " + country)
206+
return geoip[0], nil
234207
}
235208

236209
func BenchmarkGeoIPMatcher4CN(b *testing.B) {
237-
ips, err := loadGeoIP("CN")
210+
geo := "geoip:cn"
211+
geoip, err := loadGeoIP(geo)
238212
common.Must(err)
239213

240-
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
241-
Cidr: ips,
242-
})
214+
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
243215
common.Must(err)
244216

245217
b.ResetTimer()
@@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
250222
}
251223

252224
func BenchmarkGeoIPMatcher6US(b *testing.B) {
253-
ips, err := loadGeoIP("US")
225+
geo := "geoip:us"
226+
geoip, err := loadGeoIP(geo)
254227
common.Must(err)
255228

256-
matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
257-
Cidr: ips,
258-
})
229+
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
259230
common.Must(err)
260231

261232
b.ResetTimer()

app/router/condition_test.go

Lines changed: 65 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,22 @@
11
package router_test
22

33
import (
4+
"os"
5+
"path/filepath"
6+
"runtime"
47
"strconv"
58
"testing"
69

10+
"github.com/xtls/xray-core/app/router"
711
. "github.com/xtls/xray-core/app/router"
812
"github.com/xtls/xray-core/common"
9-
"github.com/xtls/xray-core/common/errors"
1013
"github.com/xtls/xray-core/common/net"
11-
"github.com/xtls/xray-core/common/platform/filesystem"
1214
"github.com/xtls/xray-core/common/protocol"
1315
"github.com/xtls/xray-core/common/protocol/http"
1416
"github.com/xtls/xray-core/common/session"
1517
"github.com/xtls/xray-core/features/routing"
1618
routing_session "github.com/xtls/xray-core/features/routing/session"
17-
"google.golang.org/protobuf/proto"
19+
"github.com/xtls/xray-core/infra/conf"
1820
)
1921

2022
func withBackground() routing.Context {
@@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) {
300302
}
301303
}
302304

303-
func loadGeoSite(country string) ([]*Domain, error) {
304-
path, err := getAssetPath("geosite.dat")
305-
if err != nil {
306-
return nil, err
307-
}
308-
geositeBytes, err := filesystem.ReadFile(path)
309-
if err != nil {
310-
return nil, err
311-
}
305+
func loadGeoSiteDomains(geo string) ([]*Domain, error) {
306+
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))
312307

313-
var geositeList GeoSiteList
314-
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
308+
domains, err := conf.ParseDomainRule(geo)
309+
if err != nil {
315310
return nil, err
316311
}
317312

318-
for _, site := range geositeList.Entry {
319-
if site.CountryCode == country {
320-
return site.Domain, nil
313+
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
314+
domains, err = router.GetDomainList(domains)
315+
if err != nil {
316+
return nil, err
321317
}
322318
}
323-
324-
return nil, errors.New("country not found: " + country)
319+
return domains, nil
325320
}
326321

327322
func TestChinaSites(t *testing.T) {
328-
domains, err := loadGeoSite("CN")
323+
domains, err := loadGeoSiteDomains("geosite:cn")
329324
common.Must(err)
330325

331326
acMatcher, err := NewMphMatcherGroup(domains)
@@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) {
366361
}
367362
}
368363

364+
func TestChinaSitesWithAttrs(t *testing.T) {
365+
domains, err := loadGeoSiteDomains("geosite:google@cn")
366+
common.Must(err)
367+
368+
acMatcher, err := NewMphMatcherGroup(domains)
369+
common.Must(err)
370+
371+
type TestCase struct {
372+
Domain string
373+
Output bool
374+
}
375+
testCases := []TestCase{
376+
{
377+
Domain: "google.cn",
378+
Output: true,
379+
},
380+
{
381+
Domain: "recaptcha.net",
382+
Output: true,
383+
},
384+
{
385+
Domain: "164.com",
386+
Output: false,
387+
},
388+
{
389+
Domain: "164.com",
390+
Output: false,
391+
},
392+
}
393+
394+
for i := 0; i < 1024; i++ {
395+
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
396+
}
397+
398+
for _, testCase := range testCases {
399+
r := acMatcher.ApplyDomain(testCase.Domain)
400+
if r != testCase.Output {
401+
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
402+
}
403+
}
404+
}
405+
369406
func BenchmarkMphDomainMatcher(b *testing.B) {
370-
domains, err := loadGeoSite("CN")
407+
domains, err := loadGeoSiteDomains("geosite:cn")
371408
common.Must(err)
372409

373410
matcher, err := NewMphMatcherGroup(domains)
@@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
412449
var geoips []*GeoIP
413450

414451
{
415-
ips, err := loadGeoIP("CN")
452+
ips, err := loadGeoIP("geoip:cn")
416453
common.Must(err)
417454
geoips = append(geoips, &GeoIP{
418455
CountryCode: "CN",
419-
Cidr: ips,
456+
Cidr: ips.Cidr,
420457
})
421458
}
422459

@@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
425462
common.Must(err)
426463
geoips = append(geoips, &GeoIP{
427464
CountryCode: "JP",
428-
Cidr: ips,
465+
Cidr: ips.Cidr,
429466
})
430467
}
431468

432469
{
433-
ips, err := loadGeoIP("CA")
470+
ips, err := loadGeoIP("geoip:ca")
434471
common.Must(err)
435472
geoips = append(geoips, &GeoIP{
436473
CountryCode: "CA",
437-
Cidr: ips,
474+
Cidr: ips.Cidr,
438475
})
439476
}
440477

441478
{
442-
ips, err := loadGeoIP("US")
479+
ips, err := loadGeoIP("geoip:us")
443480
common.Must(err)
444481
geoips = append(geoips, &GeoIP{
445482
CountryCode: "US",
446-
Cidr: ips,
483+
Cidr: ips.Cidr,
447484
})
448485
}
449486

app/router/config.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
112112
domains := rr.Domain
113113
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
114114
var err error
115-
domains, err = getDomainList(rr.Domain)
115+
domains, err = GetDomainList(rr.Domain)
116116
if err != nil {
117117
return nil, errors.New("failed to build domains from mmap").Base(err)
118118
}
@@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
122122
if err != nil {
123123
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
124124
}
125-
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
125+
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
126126
conds.Add(matcher)
127127
}
128128

@@ -214,7 +214,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
214214

215215
}
216216

217-
func getDomainList(domains []*Domain) ([]*Domain, error) {
217+
func GetDomainList(domains []*Domain) ([]*Domain, error) {
218218
domainList := []*Domain{}
219219
for _, domain := range domains {
220220
val := strings.Split(domain.Value, "_")

common/platform/windows.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33

44
package platform
55

6-
import "path/filepath"
6+
import (
7+
"path/filepath"
8+
)
79

810
func LineSeparator() string {
911
return "\r\n"
@@ -12,6 +14,7 @@ func LineSeparator() string {
1214
// GetAssetLocation searches for `file` in the env dir and the executable dir
1315
func GetAssetLocation(file string) string {
1416
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)
17+
1518
return filepath.Join(assetPath, file)
1619
}
1720

infra/conf/dns.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
8989
var originalRules []*dns.NameServer_OriginalRule
9090

9191
for _, rule := range c.Domains {
92-
parsedDomain, err := parseDomainRule(rule)
92+
parsedDomain, err := ParseDomainRule(rule)
9393
if err != nil {
9494
return nil, errors.New("invalid domain rule: ", rule).Base(err)
9595
}

infra/conf/router.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
291291
return filteredDomains, nil
292292
}
293293

294-
func parseDomainRule(domain string) ([]*router.Domain, error) {
294+
func ParseDomainRule(domain string) ([]*router.Domain, error) {
295295
if strings.HasPrefix(domain, "geosite:") {
296296
country := strings.ToUpper(domain[8:])
297297
domains, err := loadGeositeWithAttr("geosite.dat", country)
@@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
489489

490490
if rawFieldRule.Domain != nil {
491491
for _, domain := range *rawFieldRule.Domain {
492-
rules, err := parseDomainRule(domain)
492+
rules, err := ParseDomainRule(domain)
493493
if err != nil {
494494
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
495495
}
@@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {
499499

500500
if rawFieldRule.Domains != nil {
501501
for _, domain := range *rawFieldRule.Domains {
502-
rules, err := parseDomainRule(domain)
502+
rules, err := ParseDomainRule(domain)
503503
if err != nil {
504504
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
505505
}

0 commit comments

Comments
 (0)