Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 22 additions & 6 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,26 @@ func (c *Config) InitDefaults() {
}
}

// dsn is a parsed "scheme://address" RPC listen string.
type dsn struct {
scheme string
addr string
}

// parseDSN splits a "scheme://address" listen string into its scheme and
// address. It errors unless the string contains exactly one "://" separator.
func parseDSN(listen string) (dsn, error) {
scheme, addr, ok := strings.Cut(listen, "://")
if !ok || strings.Contains(addr, "://") {
return dsn{}, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
}
return dsn{scheme: scheme, addr: addr}, nil
}

// Valid returns nil if config is valid.
func (c *Config) Valid() error {
if dsn := strings.Split(c.Listen, "://"); len(dsn) != 2 {
return errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
if _, err := parseDSN(c.Listen); err != nil {
return err
}
if c.RequestTimeout < 0 {
return errors.New("rpc request_timeout must be non-negative")
Expand All @@ -63,10 +79,10 @@ func (c *Config) Listener() (net.Listener, error) {

// Dialer creates rpc socket Dialer.
func (c *Config) Dialer() (net.Conn, error) {
dsn := strings.Split(c.Listen, "://")
if len(dsn) != 2 {
return nil, errors.New("invalid socket DSN (tcp://:6001, unix://file.sock)")
parsed, err := parseDSN(c.Listen)
if err != nil {
return nil, err
}
var d net.Dialer
return d.DialContext(context.Background(), dsn[0], dsn[1])
return d.DialContext(context.Background(), parsed.scheme, parsed.addr)
}
4 changes: 4 additions & 0 deletions go.work.sum
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOM
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/crypto v0.51.0 h1:IBPXwPfKxY7cWQZ38ZCIRPI50YLeevDLlLnyC5wRGTI=
golang.org/x/crypto v0.51.0/go.mod h1:8AdwkbraGNABw2kOX6YFPs3WM22XqI4EXEd8g+x7Oc8=
golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6 h1:QE6XYQK6naiK1EPAe1g/ILLxN5RBoH5xkJk3CqlMI/Y=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b h1:+qEpEAPhDZ1o0x3tHzZTQDArnOixOzGD9HUJfcg0mb4=
golang.org/x/lint v0.0.0-20201208152925-83fdc39ff7b5 h1:2M3HP5CCK1Si9FQhwnzYhXdG6DXeebvUHFpre8QvbyI=
Expand Down Expand Up @@ -416,6 +418,8 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1 h1:v+OssWQX+hTHEmOBgwxdZxK4zHq3yOs8F9J7mk0PY8E=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
golang.org/x/term v0.43.0 h1:S4RLU2sB31O/NCl+zFN9Aru9A/Cq2aqKpTZJ6B+DwT4=
golang.org/x/term v0.43.0/go.mod h1:lrhlHNdQJHO+1qVYiHfFKVuVioJIheAc3fBSMFYEIsk=
golang.org/x/time v0.1.0 h1:xYY+Bajn2a7VBmTM5GikTmnK8ZuX8YgnQCqZpbBNtmA=
golang.org/x/time v0.1.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
Expand Down
11 changes: 4 additions & 7 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ func (s *Plugin) Init(cfg Configurer, log Logger) error {
return errors.E(op, err)
}

var WholeCfg any
err = cfg.Unmarshal(&WholeCfg)
var wholeCfg any
err = cfg.Unmarshal(&wholeCfg)
if err != nil {
return errors.E(op, err)
}

s.wcfg, err = json.Marshal(WholeCfg)
s.wcfg, err = json.Marshal(wholeCfg)
if err != nil {
return err
}
Expand Down Expand Up @@ -125,10 +125,7 @@ func (s *Plugin) Serve() chan error {
mux.Handle(path, handler)
// derive the gRPC service name from the mount path
// (`/<service>/<Method>` or `/<service>/`)
svc := strings.TrimPrefix(path, "/")
if i := strings.Index(svc, "/"); i >= 0 {
svc = svc[:i]
}
svc, _, _ := strings.Cut(strings.TrimPrefix(path, "/"), "/")
services = append(services, svc)
}

Expand Down
12 changes: 12 additions & 0 deletions tests/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,18 @@ func Test_Config_DialerErrorMethod(t *testing.T) {
assert.Error(t, err)
}

func Test_Config_MultipleSeparators(t *testing.T) {
// A DSN with more than one "://" must be rejected by both Valid and Dialer.
cfg := &rpc.Config{Listen: "tcp://host://6001"}

assert.Error(t, cfg.Valid())

conn, err := cfg.Dialer()
assert.Nil(t, conn)
assert.Error(t, err)
assert.Equal(t, "invalid socket DSN (tcp://:6001, unix://file.sock)", err.Error())
}

func Test_Config_Defaults(t *testing.T) {
c := &rpc.Config{}
c.InitDefaults()
Expand Down
Loading