diff --git a/session.go b/session.go index abc8e22a7..9c2246713 100644 --- a/session.go +++ b/session.go @@ -795,6 +795,57 @@ type urlInfoOption struct { } func extractURL(s string) (*urlInfo, error) { + if strings.Contains(s, "mongodb+srv://") && strings.Contains(s, "mongodb.net") { + s = strings.TrimPrefix(s, "mongodb+srv://") + info := &urlInfo{options: []urlInfoOption{}} + + info.options = append(info.options, urlInfoOption{key: "ssl", value: "true"}) + info.options = append(info.options, urlInfoOption{key: "authSource", value: "admin"}) + info.options = append(info.options, urlInfoOption{key: "replicaSet", value: "Cluster0-shard-0"}) + + if c := strings.Index(s, "?"); c != -1 { + s = s[:c] + } + + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return nil, errors.New("credentials must be provided as user:pass@host") + } + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("cannot unescape password in URL") + } + } + s = s[c+1:] + } + + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] + } + + _, addrs, err := net.LookupSRV("mongodb", "tcp", s) + if err != nil { + info.addrs = strings.Split(s, ",") + return info, nil + } + + for _, addr := range addrs { + target := addr.Target + port := addr.Port + info.addrs = append(info.addrs, target[:len(target)-1]+":"+strconv.FormatUint(uint64(port), 10)) + } + + return info, nil + } + s = strings.TrimPrefix(s, "mongodb://") info := &urlInfo{options: []urlInfoOption{}}