diff --git a/lib/targets.go b/lib/targets.go index 02d06b4c..752a6de3 100644 --- a/lib/targets.go +++ b/lib/targets.go @@ -116,22 +116,37 @@ type Targeter func(*Target) error // hdr will be merged with the each Target's headers. // func NewJSONTargeter(src io.Reader, body []byte, header http.Header) Targeter { - type decoder struct { - *json.Decoder + type reader struct { + *bufio.Reader sync.Mutex } - dec := decoder{Decoder: json.NewDecoder(src)} + rd := reader{Reader: bufio.NewReader(src)} return func(tgt *Target) (err error) { if tgt == nil { return ErrNilTarget } - dec.Lock() - defer dec.Unlock() + rd.Lock() + defer rd.Unlock() + + var line []byte + for len(line) == 0 { + if line, err = rd.ReadBytes('\n'); err != nil { + break + } + line = bytes.TrimSpace(line) // Skip empty lines + } + + if err != nil { + if err == io.EOF { + err = ErrNoTargets + } + return err + } var t Target - if err = dec.Decode(&t); err != nil && err != io.EOF { + if err = json.Unmarshal(line, &t); err != nil { return err } else if t.Method == "" { return ErrNoMethod @@ -157,11 +172,7 @@ func NewJSONTargeter(src io.Reader, body []byte, header http.Header) Targeter { tgt.Header[k] = append(tgt.Header[k], vs...) } - if err == io.EOF { - err = ErrNoTargets - } - - return err + return nil } } diff --git a/lib/targets_test.go b/lib/targets_test.go index 60ca9875..b3527f36 100644 --- a/lib/targets_test.go +++ b/lib/targets_test.go @@ -59,6 +59,10 @@ func TestTargetRequest(t *testing.T) { } func TestJSONTargeter(t *testing.T) { + target := func(s string) io.Reader { + return strings.NewReader(s + "\n") + } + for _, tc := range []struct { name string src io.Reader @@ -80,56 +84,73 @@ func TestJSONTargeter(t *testing.T) { src: &bytes.Buffer{}, in: &Target{}, out: &Target{}, - err: ErrNoMethod, + err: ErrNoTargets, + }, + { + name: "no new line", + src: strings.NewReader(`{"method": "GET", "url": "https://goku"}`), + in: &Target{}, + out: &Target{}, + err: ErrNoTargets, }, { name: "empty object", - src: strings.NewReader(`{}`), + src: target("{}"), in: &Target{}, out: &Target{}, err: ErrNoMethod, }, { name: "empty method", - src: strings.NewReader(`{"method": ""}`), + src: target(`{"method": ""}`), in: &Target{}, out: &Target{}, err: ErrNoMethod, }, { name: "empty url", - src: strings.NewReader(`{"method": "GET"}`), + src: target(`{"method": "GET"}`), in: &Target{}, out: &Target{}, err: ErrNoURL, }, { name: "bad body encoding", - src: strings.NewReader(`{"method": "GET", "url": "http://goku", "body": "NOT BASE64"}`), + src: target(`{"method": "GET", "url": "http://goku", "body": "NOT BASE64"}`), in: &Target{}, out: &Target{}, err: errors.New("illegal base64 data at input byte 3"), }, { name: "default body", - src: strings.NewReader(`{"method": "GET", "url": "http://goku"}`), + src: target(`{"method": "GET", "url": "http://goku"}`), body: []byte(`ATTACK!`), in: &Target{}, out: &Target{Method: "GET", URL: "http://goku", Body: []byte("ATTACK!")}, }, { name: "headers merge", - src: strings.NewReader(`{"method": "GET", "url": "http://goku", "header":{"x": ["foo"]}}`), + src: target(`{"method": "GET", "url": "http://goku", "header":{"x": ["foo"]}}`), hdr: http.Header{"x": []string{"bar"}}, in: &Target{Header: http.Header{"y": []string{"baz"}}}, out: &Target{Method: "GET", URL: "http://goku", Header: http.Header{"y": []string{"baz"}, "x": []string{"bar", "foo"}}}, }, { name: "no defaults", - src: strings.NewReader(`{"method": "GET", "url": "http://goku", "header":{"x": ["foo"]}, "body": "QVRUQUNLIQ=="}`), + src: target(`{"method": "GET", "url": "http://goku", "header":{"x": ["foo"]}, "body": "QVRUQUNLIQ=="}`), in: &Target{}, out: &Target{Method: "GET", URL: "http://goku", Header: http.Header{"x": []string{"foo"}}, Body: []byte("ATTACK!")}, }, + { + name: "skips empty lines and surrounding whitespace", + src: strings.NewReader(` + + {"method": "GET", "url": "https://goku"} + + `), + in: &Target{}, + out: &Target{Method: "GET", URL: "https://goku"}, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -149,32 +170,87 @@ func TestJSONTargeter(t *testing.T) { } func TestReadAllTargets(t *testing.T) { - t.Parallel() + equal := func(a, b []Target) bool { + if len(a) != len(b) { + return false + } + + for i := range a { + if !a[i].Equal(&b[i]) { + return false + } + } + + return true + } - src := []byte("GET http://:6060/\nHEAD http://:6606/") - want := []Target{ + targets := []Target{ + {Method: "GET", URL: "http://:6060/"}, + {Method: "HEAD", URL: "http://:6606/"}, + } + + for _, tc := range []struct { + name string + in Targeter + out []Target + err error + }{ { - Method: "GET", - URL: "http://:6060/", - Body: []byte("body"), - Header: http.Header{}, + name: "HTTPTargeter/single", + in: NewHTTPTargeter(strings.NewReader(`GET http://:6060/`), nil, nil), + out: targets[:1], }, { - Method: "HEAD", - URL: "http://:6606/", - Body: []byte("body"), - Header: http.Header{}, + name: "HTTPTargeter/many", + in: NewHTTPTargeter(strings.NewReader(` + GET http://:6060/ + HEAD http://:6606/ + `), nil, nil), + out: targets, }, - } + { + name: "JSONTargeter/single", + in: NewJSONTargeter(strings.NewReader(`{"method": "GET", "url": "http://:6060/"}`+"\n"), nil, nil), + out: targets[:1], + }, + { + name: "JSONTargeter/many", + in: NewJSONTargeter(strings.NewReader(` + {"method": "GET", "url": "http://:6060/"} + {"method": "HEAD", "url": "http://:6606/"} + `), nil, nil), + out: targets, + }, + { + name: "no targets", + in: NewHTTPTargeter(strings.NewReader(""), nil, nil), + err: ErrNoTargets, + }, + { + name: "unexpected error", + in: NewJSONTargeter(errReader{err: io.ErrUnexpectedEOF}, nil, nil), + err: io.ErrUnexpectedEOF, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + out, err := ReadAllTargets(tc.in) + if got, want := out, tc.out; !equal(got, want) { + t.Errorf("got targets: %#v, want %#v", got, want) + } - got, err := ReadAllTargets(NewHTTPTargeter(bytes.NewReader(src), []byte("body"), nil)) - if err != nil { - t.Fatalf("error reading all targets: %v", err) + if got, want := fmt.Sprint(err), fmt.Sprint(tc.err); got != want { + t.Errorf("got err %v, want %v", got, want) + } + }) } +} - if !reflect.DeepEqual(got, want) { - t.Fatalf("got: %#v, want: %#v", got, want) - } +type errReader struct{ err error } + +func (e errReader) Read(p []byte) (n int, err error) { + return 0, e.err } func TestNewHTTPTargeter(t *testing.T) { @@ -189,7 +265,7 @@ func TestNewHTTPTargeter(t *testing.T) { GET http://:6060 @238hhqwjhd8hhw3r.txt`, errors.New("bad header"): ` - GET http://:6060 + GET http://:6060 Authorization`, errors.New("bad header"): ` GET http://:6060