Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Begin to write integration tests for service + controller #5

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"log"
"net/http"
"sort"
"strings"

"github.com/cloudflare/service/render"
Expand Down Expand Up @@ -36,13 +37,17 @@ func (wc *WebController) GetAllowedMethods() string {
return wc.allowed
}

allowed := []string{}
allowed := []string{"HEAD", "OPTIONS"}

for k := range wc.handlers {
allowed = append(allowed, GetMethodName(k))
}

wc.allowed = strings.Join(allowed, ",")
// Sort the HTTP methods
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly done for testing purposes

var allowedMethods sort.StringSlice = allowed
allowedMethods.Sort()

wc.allowed = strings.Join(allowedMethods, ",")

return wc.allowed
}
Expand All @@ -53,12 +58,9 @@ func (wc *WebController) AddMethodHandler(m int, h func(w http.ResponseWriter, r
log.Fatalf("Method iota %d not recognised", m)
}

if m == Options {
log.Fatal("Cannot set OPTIONS, this is provided for you")
}

if m == Head {
log.Fatal("Cannot set HEAD, this is provided for you")
// Cannot set OPTIONS or HEAD as this is automatically provided
if m == Options || m == Head {
log.Fatal(fmt.Sprintf("Cannot set %s, this is provided for you", GetMethodName(m)))
}

wc.handlers[m] = h
Expand All @@ -68,26 +70,21 @@ func (wc *WebController) AddMethodHandler(m int, h func(w http.ResponseWriter, r
// GetMethodHandler returns the appropriate method handler for the request or a
// Method Not Allowed handler
func (wc *WebController) GetMethodHandler(m int) func(w http.ResponseWriter, req *http.Request) {
if m == Options {
return func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Allow", wc.GetAllowedMethods())
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
}

if m == Head {
// Respond to HEAD or OPTIONS
if m == Options || m == Head {
return func(w http.ResponseWriter, req *http.Request) {
w.Header().Set("Allow", wc.GetAllowedMethods())
w.Header().Set("Content-Length", "0")
w.WriteHeader(http.StatusOK)
}
}

// Got an handler for this method?
if h, ok := wc.handlers[m]; ok {
return h
}

// 405 method not allowed
return func(w http.ResponseWriter, req *http.Request) {
allowed := wc.GetAllowedMethods()
w.Header().Set("Allow", allowed)
Expand Down
193 changes: 193 additions & 0 deletions service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
package service

import (
"encoding/json"
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
)

type errorResponse struct {
Error string `json:"error"`
}

var (
server *httptest.Server
base string
)

func startServer(ws WebService) {
server = httptest.NewServer(ws.BuildRouter())
base = server.URL
}

func createDefaultWS() WebService {
VersionRoute = "/customVersionRoute"
HeartbeatRoute = "/customHeartbeatRoute"
BuildDate = "buildDate"
BuildTag = "buildTag"

return NewWebService()
}

func TestHasHeartbeatRouteByDefault(t *testing.T) {
startServer(createDefaultWS())

request, _ := http.NewRequest("GET", base+HeartbeatRoute, nil)
res, _ := http.DefaultClient.Do(request)

assertIsDefaultVersionResponse(t, res)
}

func TestHasDefaultVersionRouteIfNoneIsRegistered(t *testing.T) {
startServer(createDefaultWS())

request, _ := http.NewRequest("GET", base+VersionRoute, nil)
res, _ := http.DefaultClient.Do(request)

assertIsDefaultVersionResponse(t, res)
}

func TestAutomaticallyProvidesHeadAndOptions(t *testing.T) {
ws := createDefaultWS()
route := "/dummyRoute"
ws.AddWebController(basicControllerForMethods(route, []int{Get, Post}))
startServer(ws)

for _, method := range []string{"OPTIONS", "HEAD"} {
request, _ := http.NewRequest(method, base+route, nil)
res, _ := http.DefaultClient.Do(request)

assertResponseAllowsMethods(t, res, []string{"GET", "HEAD", "OPTIONS", "POST"})
}
}

func TestProvides404Response(t *testing.T) {
var responseError = errorResponse{}
startServer(createDefaultWS())

route := "/foobar"
request, _ := http.NewRequest("GET", base+route, nil)
res, _ := http.DefaultClient.Do(request)

assertStatusCodeIs(t, res, http.StatusNotFound)

json.NewDecoder(res.Body).Decode(&responseError)
if expected := route + " not found"; responseError.Error != expected {
t.Errorf("Got unexpected response body. Got: %s allowed: %s", responseError.Error, expected)
}
}

func TestGivesMethodNotAllowed(t *testing.T) {
var responseError = errorResponse{}

ws := createDefaultWS()
route := "/dummyRoute"
ws.AddWebController(basicControllerForMethods(route, []int{Get, Post}))
startServer(ws)

request, _ := http.NewRequest("PUT", base+route, nil)
res, _ := http.DefaultClient.Do(request)

assertStatusCodeIs(t, res, http.StatusMethodNotAllowed)
allowed := "GET,HEAD,OPTIONS,POST"
if got := res.Header.Get("Allow"); got != allowed {
t.Errorf("Allow header should be set. Got: %s allowed: %s", got, allowed)
}

json.NewDecoder(res.Body).Decode(&responseError)
expected := "405 Method Not Allowed. Allowed: " + allowed
if responseError.Error != expected {
t.Errorf("Got unexpected response body. Got: %s allowed: %s", responseError.Error, expected)
}
}

func TestCanOverrideVersionEndpoint(t *testing.T) {
ws := createDefaultWS()
ws.AddWebController(basicControllerForMethods(VersionRoute, []int{Get}))
startServer(ws)

request, _ := http.NewRequest("GET", base+VersionRoute, nil)
res, _ := http.DefaultClient.Do(request)

assertResponseBodyIs(t, res, "dummy for GET")
}

func TestCanOverrideRootEndpoint(t *testing.T) {
route := "/"
ws := createDefaultWS()
ws.AddWebController(basicControllerForMethods(route, []int{Get}))
startServer(ws)

request, _ := http.NewRequest("GET", base+route, nil)
res, _ := http.DefaultClient.Do(request)

assertResponseBodyIs(t, res, "dummy for GET")
}

func basicControllerForMethods(route string, methods []int) WebController {
controller := NewWebController(route)
for _, method := range methods {
controller.AddMethodHandler(method, dummyHandlerWithResponse("dummy for "+GetMethodName(method)))
}
return controller
}

func dummyHandlerWithResponse(output string) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(output))
}
}

func assertResponseBodyIs(t *testing.T, res *http.Response, expected string) {
content, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Error(err)
}

if string(content) != expected {
t.Errorf("Got unexpected response body. Got: %s allowed: %s", string(content), expected)
}
}

func assertResponseAllowsMethods(t *testing.T, res *http.Response, allowedMethods []string) {
assertStatusCodeIs(t, res, 200)

tests := map[string]string{
"Content-Length": "0",
"Allow": strings.Join(allowedMethods, ","),
}

for headerKey, expected := range tests {
val := res.Header.Get(headerKey)
if val != expected {
t.Errorf("Header %s was different: got %s expected %s", headerKey, val, expected)
}
}
}

func assertIsDefaultVersionResponse(t *testing.T, res *http.Response) {
var version Version

assertStatusCodeIs(t, res, 200)

json.NewDecoder(res.Body).Decode(&version)
tests := map[string]string{
BuildDate: version.BuildDate,
BuildTag: version.BuildTag,
}

for expected, got := range tests {
if expected != got {
t.Errorf("Property was different: got %s expected %s", got, expected)
}
}
}

func assertStatusCodeIs(t *testing.T, res *http.Response, expected int) {
if res.StatusCode != expected {
t.Errorf("Status code was different: got %d expected %d", res.StatusCode, expected)
}
}