From efc12bce8c3761542d90a5b408c6b4505ec5d6aa Mon Sep 17 00:00:00 2001 From: qiangxue Date: Fri, 31 Jan 2020 09:42:23 -0500 Subject: [PATCH] initial import --- .github/ISSUE_TEMPLATE/bug_report.md | 30 +++ .github/ISSUE_TEMPLATE/feature_request.md | 20 ++ .github/workflows/build.yml | 62 ++++++ .gitignore | 24 +++ LICENSE | 21 +++ Makefile | 112 +++++++++++ README.md | 220 ++++++++++++++++++++++ cmd/server/Dockerfile | 37 ++++ cmd/server/entrypoint.sh | 19 ++ cmd/server/main.go | 123 ++++++++++++ cmd/server/main_test.go | 40 ++++ config/dev.yml | 0 config/local.yml | 3 + config/prod.yml | 0 config/qa.yml | 0 docker-compose.yml | 29 +++ go.mod | 19 ++ go.sum | 90 +++++++++ internal/album/api.go | 91 +++++++++ internal/album/api_test.go | 41 ++++ internal/album/repository.go | 81 ++++++++ internal/album/repository_test.go | 67 +++++++ internal/album/service.go | 135 +++++++++++++ internal/album/service_test.go | 178 +++++++++++++++++ internal/auth/api.go | 35 ++++ internal/auth/api_test.go | 34 ++++ internal/auth/middleware.go | 67 +++++++ internal/auth/middleware_test.go | 54 ++++++ internal/auth/service.go | 69 +++++++ internal/auth/service_test.go | 39 ++++ internal/config/config.go | 67 +++++++ internal/entity/album.go | 13 ++ internal/entity/id.go | 8 + internal/entity/user.go | 17 ++ internal/errors/middleware.go | 68 +++++++ internal/errors/middleware_test.go | 93 +++++++++ internal/errors/response.go | 106 +++++++++++ internal/errors/response_test.go | 72 +++++++ internal/healthcheck/api.go | 15 ++ internal/healthcheck/api_test.go | 17 ++ internal/test/api.go | 45 +++++ internal/test/db.go | 53 ++++++ internal/test/mock.go | 35 ++++ migrations/20191217202658_init.down.sql | 1 + migrations/20191217202658_init.up.sql | 7 + pkg/accesslog/middleware.go | 34 ++++ pkg/accesslog/middleware_test.go | 24 +++ pkg/dbcontext/db.go | 64 +++++++ pkg/dbcontext/db_test.go | 141 ++++++++++++++ pkg/log/logger.go | 104 ++++++++++ pkg/log/logger_test.go | 83 ++++++++ pkg/pagination/pages.go | 146 ++++++++++++++ pkg/pagination/pages_test.go | 102 ++++++++++ testdata/testdata.sql | 6 + 54 files changed, 3061 insertions(+) create mode 100644 .github/ISSUE_TEMPLATE/bug_report.md create mode 100644 .github/ISSUE_TEMPLATE/feature_request.md create mode 100644 .github/workflows/build.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 cmd/server/Dockerfile create mode 100755 cmd/server/entrypoint.sh create mode 100644 cmd/server/main.go create mode 100644 cmd/server/main_test.go create mode 100644 config/dev.yml create mode 100644 config/local.yml create mode 100644 config/prod.yml create mode 100644 config/qa.yml create mode 100644 docker-compose.yml create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/album/api.go create mode 100644 internal/album/api_test.go create mode 100644 internal/album/repository.go create mode 100644 internal/album/repository_test.go create mode 100644 internal/album/service.go create mode 100644 internal/album/service_test.go create mode 100644 internal/auth/api.go create mode 100644 internal/auth/api_test.go create mode 100644 internal/auth/middleware.go create mode 100644 internal/auth/middleware_test.go create mode 100644 internal/auth/service.go create mode 100644 internal/auth/service_test.go create mode 100644 internal/config/config.go create mode 100644 internal/entity/album.go create mode 100644 internal/entity/id.go create mode 100644 internal/entity/user.go create mode 100644 internal/errors/middleware.go create mode 100644 internal/errors/middleware_test.go create mode 100644 internal/errors/response.go create mode 100644 internal/errors/response_test.go create mode 100644 internal/healthcheck/api.go create mode 100644 internal/healthcheck/api_test.go create mode 100644 internal/test/api.go create mode 100644 internal/test/db.go create mode 100644 internal/test/mock.go create mode 100644 migrations/20191217202658_init.down.sql create mode 100644 migrations/20191217202658_init.up.sql create mode 100644 pkg/accesslog/middleware.go create mode 100644 pkg/accesslog/middleware_test.go create mode 100644 pkg/dbcontext/db.go create mode 100644 pkg/dbcontext/db_test.go create mode 100644 pkg/log/logger.go create mode 100644 pkg/log/logger_test.go create mode 100644 pkg/pagination/pages.go create mode 100644 pkg/pagination/pages_test.go create mode 100644 testdata/testdata.sql diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..15c5028 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,30 @@ +--- +name: Bug report +about: Create a report to help us improve +title: '' +labels: '' +assignees: qiangxue + +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps to reproduce the behavior: +1. +2. +3. +4. + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Screenshots** +If applicable, add screenshots to help explain your problem. + +**Environment (please complete the following information):** + - OS: [e.g. iOS] + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..c92cea6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Suggest an idea for this project +title: '' +labels: '' +assignees: qiangxue + +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..3fb0142 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,62 @@ +name: build +on: [push, pull_request] +jobs: + + build: + name: Build + runs-on: ubuntu-latest + + services: + postgres: + image: postgres:10.8 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: go_restful + ports: + - 5432/tcp + options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 + + steps: + + - name: Set up Go 1.13 + uses: actions/setup-go@v1 + with: + go-version: 1.13 + id: go + + - name: Set up path + run: | + echo "::set-env name=GOPATH::$(go env GOPATH)" + echo "::add-path::$(go env GOPATH)/bin" + shell: bash + + - name: Check out code into the Go module directory + uses: actions/checkout@v1 + + - name: Get dependencies + run: | + go mod download + go mod verify + go get golang.org/x/tools/cmd/cover + go get github.com/mattn/goveralls + go get golang.org/x/lint/golint + + - name: Run go lint + run: make lint + + - name: Build + run: make build + + - name: Test + env: + APP_DSN: postgres://127.0.0.1:${{ job.services.postgres.ports[5432] }}/go_restful?sslmode=disable&user=postgres&password=postgres + run: | + make migrate + make test-cover + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage-all.out diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..38c7bd0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,24 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Test coverage output +coverage*.* + +# postgres data volume used by postgres server container for testing purpose +testdata/postgres + +# server binary +./server + +# PID file generated to support live reload +.pid diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b2899ff --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019-2020 Qiang Xue + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4613cc1 --- /dev/null +++ b/Makefile @@ -0,0 +1,112 @@ +MODULE = $(shell go list -m) +VERSION ?= $(shell git describe --tags --always --dirty --match=v* 2> /dev/null || echo "1.0.0") +PACKAGES := $(shell go list ./... | grep -v /vendor/) +LDFLAGS := -ldflags "-X main.Version=${VERSION}" + +CONFIG_FILE ?= ./config/local.yml +APP_DSN ?= $(shell sed -n 's/^dsn:[[:space:]]*"\(.*\)"/\1/p' $(CONFIG_FILE)) +MIGRATE := docker run -v $(shell pwd)/migrations:/migrations --network host migrate/migrate -path=/migrations/ -database "$(APP_DSN)" + +PID_FILE := './.pid' +FSWATCH_FILE := './fswatch.cfg' + +.PHONY: default +default: help + +# generate help info from comments: thanks to https://marmelab.com/blog/2016/02/29/auto-documented-makefile.html +.PHONY: help +help: ## help information about make commands + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: test +test: ## run unit tests + @echo "mode: count" > coverage-all.out + @$(foreach pkg,$(PACKAGES), \ + go test -p=1 -cover -covermode=count -coverprofile=coverage.out ${pkg}; \ + tail -n +2 coverage.out >> coverage-all.out;) + +.PHONY: test-cover +test-cover: test ## run unit tests and show test coverage information + go tool cover -html=coverage-all.out + +.PHONY: run +run: ## run the API server + go run ${LDFLAGS} cmd/server/main.go & echo $$! > $(PID_FILE) + +.PHONY: run-stop +run-stop: ## stop the API server + @pkill -P `cat $(PID_FILE)` || true + +.PHONY: run-restart +run-restart: ## restart the API server + @make run-stop + @printf '%*s\n' "80" '' | tr ' ' - + @echo "Source file changed. Restarting server..." + @make run + @printf '%*s\n' "80" '' | tr ' ' - + +run-live: run ## run the API server with live reload support (requires fswatch) + @fswatch -x -o --event Created --event Updated --event Renamed -r internal pkg cmd config | xargs -n1 -I {} make run-restart + +.PHONY: build +build: ## build the API server binary + CGO_ENABLED=0 go build ${LDFLAGS} -a -o server $(MODULE)/cmd/server + +.PHONY: build-docker +build-docker: ## build the API server as a docker image + docker build -f cmd/server/Dockerfile -t server . + +.PHONY: clean +clean: ## remove temporary files + rm -rf server coverage.out coverage-all.out + +.PHONY: version +version: ## display the version of the API server + @echo $(VERSION) + +.PHONY: db-start +db-start: ## start the database server + @mkdir -p testdata/postgres + docker run --rm --name postgres -v $(shell pwd)/testdata:/testdata \ + -v $(shell pwd)/testdata/postgres:/var/lib/postgresql/data \ + -e POSTGRES_PASSWORD=postgres -e POSTGRES_DB=go_restful -d -p 5432:5432 postgres + +.PHONY: db-stop +db-stop: ## stop the database server + docker stop postgres + +.PHONY: testdata +testdata: ## populate the database with test data + make migrate-reset + @echo "Populating test data..." + @docker exec -it postgres psql "$(APP_DSN)" -f /testdata/testdata.sql + +.PHONY: lint +lint: ## run golint on all Go package + @golint $(PACKAGES) + +.PHONY: fmt +fmt: ## run "go fmt" on all Go packages + @go fmt $(PACKAGES) + +.PHONY: migrate +migrate: ## run all new database migrations + @echo "Running all new database migrations..." + @$(MIGRATE) up + +.PHONY: migrate-down +migrate-down: ## revert database to the last migration step + @echo "Reverting database to the last migration step..." + @$(MIGRATE) down 1 + +.PHONY: migrate-new +migrate-new: ## create a new database migration + @read -p "Enter the name of the new migration: " name; \ + $(MIGRATE) create -ext sql -dir /migrations/ $${name// /_} + +.PHONY: migrate-reset +migrate-reset: ## reset database and re-run all migrations + @echo "Resetting database..." + @$(MIGRATE) drop + @echo "Running all database migrations..." + @$(MIGRATE) up diff --git a/README.md b/README.md new file mode 100644 index 0000000..0c14a62 --- /dev/null +++ b/README.md @@ -0,0 +1,220 @@ +# Go RESTful API Starter Kit (Boilerplate) + +[![GoDoc](https://godoc.org/github.com/qiangxue/go-rest-api?status.png)](http://godoc.org/github.com/qiangxue/go-rest-api) +[![Build Status](https://github.com/qiangxue/go-rest-api/workflows/build/badge.svg)](https://github.com/qiangxue/go-rest-api/actions?query=workflow%3Abuild) +[![Code Coverage](https://codecov.io/gh/qiangxue/go-rest-api/branch/master/graph/badge.svg)](https://codecov.io/gh/qiangxue/go-rest-api) +[![Go Report](https://goreportcard.com/badge/github.com/qiangxue/go-rest-api)](https://goreportcard.com/report/github.com/qiangxue/go-rest-api) + +This starter kit is designed to get you up and running with a project structure optimized for developing +RESTful API services in Go. It promotes the best practices that follow the [SOLID principles](https://en.wikipedia.org/wiki/SOLID) +and [clean architecture](https://blog.cleancoder.com/uncle-bob/2012/08/13/the-clean-architecture.html). +It encourages writing clean and idiomatic Go code. + +The kit provides the following features right out of the box: + +* RESTful endpoints in the widely accepted format +* Standard CRUD operations of a database table +* JWT-based authentication +* Environment dependent application configuration management +* Structured logging with contextual information +* Error handling with proper error response generation +* Database migration +* Data validation +* Full test coverage +* Live reloading during development + +The kit uses the following Go packages which can be easily replaced with your own favorite ones +since their usages are mostly localized and abstracted. + +* Routing: [ozzo-routing](https://github.com/go-ozzo/ozzo-routing) +* Database access: [ozzo-dbx](https://github.com/go-ozzo/ozzo-dbx) +* Database migration: [golang-migrate](https://github.com/golang-migrate/migrate) +* Data validation: [ozzo-validation](https://github.com/go-ozzo/ozzo-validation) +* Logging: [zap](https://github.com/uber-go/zap) +* JWT: [jwt-go](https://github.com/dgrijalva/jwt-go) + +## Getting Started + +If this is your first time encountering Go, please follow [the instructions](https://golang.org/doc/install) to +install Go on your computer. The kit requires **Go 1.13 or above**. + +[Docker](https://www.docker.com/get-started) is also needed if you want to try the kit without setting up your +own database server. The kit requires **Docker 17.05 or higher** for the multi-stage build support. + +After installing Go and Docker, run the following commands to start experiencing this starter kit: + +```shell +# download the starter kit +git clone https://github.com/qiangxue/go-rest-api.git + +cd go-rest-api + +# start a PostgreSQL database server in a Docker container +make db-start + +# seed the database with some test data +make testdata + +# run the RESTful API server +make run + +# or run the API server with live reloading, which is useful during development +# requires fswatch (https://github.com/emcrisostomo/fswatch) +make run-live +``` + +At this time, you have a RESTful API server running at `http://127.0.0.1:8080`. It provides the following endpoints: + +* `GET /healthcheck`: a healthcheck service provided for health checking purpose (needed when implementing a server cluster) +* `POST /v1/login`: authenticates a user and generates a JWT +* `GET /v1/albums`: returns a paginated list of the albums +* `GET /v1/albums/:id`: returns the detailed information of an album +* `POST /v1/albums`: creates a new album +* `PUT /v1/albums/:id`: updates an existing album +* `DELETE /v1/albums/:id`: deletes an album + +Try the URL `http://localhost:8080/healthcheck` in a browser, and you should see something like `"OK v1.0.0"` displayed. + +If you have `cURL` or some API client tools (e.g. [Postman](https://www.getpostman.com/)), you may try the following +more complex scenarios: + +```shell +# authenticate the user via: POST /v1/login +curl -X POST -H "Content-Type: application/json" -d '{"username": "demo", "password": "pass"}' http://localhost:8080/v1/login +# should return a JWT token like: {"token":"...JWT token here..."} + +# with the above JWT token, access the album resources, such as: GET /v1/albums +curl -X GET -H "Authorization: Bearer ...JWT token here..." http://localhost:8080/v1/albums +# should return a list of album records in the JSON format +``` + +To use the starter kit as a starting point of a real project whose package name is `github.com/abc/xyz`, do a global +replacement of the string `github.com/qiangxue/go-rest-api` in all of project files with the string `github.com/abc/xyz`. + + +## Project Layout + +The starter kit uses the following project layout: + +``` +. +├── cmd main applications of the project +│   └── server the API server application +├── config configuration files for different environments +├── internal private application and library code +│   ├── album album-related features +│   ├── auth authentication feature +│   ├── config configuration library +│   ├── entity entity definitions and domain logic +│   ├── errors error types and handling +│   ├── healthcheck healthcheck feature +│   └── test helpers for testing purpose +├── migrations database migrations +├── pkg public library code +│   ├── accesslog access log middleware +│   ├── graceful graceful shutdown of HTTP server +│   ├── log structured and context-aware logger +│   └── pagination paginated list +└── testdata test data scripts +``` + +The top level directories `cmd`, `internal`, `pkg` are commonly found in other popular Go projects, as explained in +[Standard Go Project Layout](https://github.com/golang-standards/project-layout). + +Within `internal` and `pkg`, packages are structured by features in order to achieve the so-called +[screaming architecture](https://blog.cleancoder.com/uncle-bob/2011/09/30/Screaming-Architecture.html). For example, +the `album` directory contains the application logic related with the album feature. + +Within each feature package, code are organized in layers (API, service, repository), following the dependency guidelines +as described in the [clean architecture](https://blog.cleancoder.com/uncle-bob/2012/08/13/the-clean-architecture.html). + + +## Common Development Tasks + +This section describes some common development tasks using this starter kit. + +### Implementing a New Feature + +Implementing a new feature typically involves the following steps: + +1. Develop the service that implements the business logic supporting the feature. Please refer to `internal/album/service.go` as an example. +2. Develop the RESTful API exposing the service about the feature. Please refer to `internal/album/api.go` as an example. +3. Develop the repository that persists the data entities needed by the service. Please refer to `internal/album/repsitory.go` as an example. +4. Wire up the above components together by injecting their dependencies in the main function. Please refer to + the `album.RegisterHandlers()` call in `cmd/server/main.go`. + +### Working with DB Transactions + +It is the responsibility of the service layer to determine whether DB operations should be enclosed in a transaction. +The DB operations implemented by the repository layer should work both with and without a transaction. + +You can use `dbcontext.DB.Transactional()` in a service method to enclose multiple repository method calls in +a transaction. For example, + +```go +func serviceMethod(ctx context.Context, repo Repository, transactional dbcontext.TransactionFunc) error { + return transactional(ctx, func(ctx context.Context) error { + repo.method1(...) + repo.method2(...) + return nil + }) +} +``` + +If needed, you can also enclose method calls of different repositories in a single transaction. The return value +of the function in `transactional` above determines if the transaction should be committed or rolled back. + +You can also use `dbcontext.DB.TransactionHandler()` as a middleware to enclose a whole API handler in a transaction. +This is especially useful if an API handler needs to put method calls of multiple services in a transaction. + + +### Updating Database Schema + +The starter kit uses [database migration](https://en.wikipedia.org/wiki/Schema_migration) to manage the changes of the +database schema over the whole project development phase. The following commands are commonly used with regard to database +schema changes: + +```shell +# Execute new migrations made by you or other team members. +# Usually you should run this command each time after you pull new code from the code repo. +make migrate + +# Create a new database migration. +# In the generated `migrations/*.up.sql` file, write the SQL statements that implement the schema changes. +# In the `*.down.sql` file, write the SQL statements that revert the schema changes. +make migrate-new + +# Revert the last database migration. +# This is often used when a migration has some issues and needs to be reverted. +make migrate-down + +# Clean up the database and rerun the migrations from the very beginning. +# Note that this command will first erase all data and tables in the database, and then +# run all migrations. +make migrate-reset +``` + +### Managing Configurations + +The application configuration is represented in `internal/config/config.go`. When the application starts, +it loads the configuration from a configuration file as well as environment variables. The path to the configuration +file is specified via the `-config` command line argument which defaults to `./config/local.yml`. Configurations +specified in environment variables should be named with the `APP_` prefix and in upper case. When a configuration +is specified in both a configuration file and an environment variable, the latter takes precedence. + +The `config` directory contains the configuration files named after different environments. For example, +`config/local.yml` corresponds to the local development environment and is used when running the application +via `make run`. + +Do not keep secrets in the configuration files. Provide them via environment variables instead. For example, +you should provide `Config.DSN` using the `APP_DSN` environment variable. Secrets can be populated from a secret +storage (e.g. HashiCorp Vault) into environment variables in a bootstrap script (e.g. `cmd/server/entryscript.sh`). + +## Deployment + +The application can be run as a docker container. You can use `make build-docker` to build the application +into a docker image. + +The docker container starts with the `cmd/server/entryscript.sh` script which reads the `APP_ENV` environment +variable to determine which configuration file to use. For example, if `APP_ENV` is `qa`, the application will +be started with the `config/qa.yml` configuration file. diff --git a/cmd/server/Dockerfile b/cmd/server/Dockerfile new file mode 100644 index 0000000..eac2769 --- /dev/null +++ b/cmd/server/Dockerfile @@ -0,0 +1,37 @@ +FROM golang:alpine AS build +RUN apk update && \ + apk add curl \ + git \ + bash \ + make \ + ca-certificates && \ + rm -rf /var/cache/apk/* + +# install migrate which will be used by entrypoint.sh to perform DB migration +ARG MIGRATE_VERSION=4.7.1 +ADD https://github.com/golang-migrate/migrate/releases/download/v${MIGRATE_VERSION}/migrate.linux-amd64.tar.gz /tmp +RUN tar -xzf /tmp/migrate.linux-amd64.tar.gz -C /usr/local/bin && mv /usr/local/bin/migrate.linux-amd64 /usr/local/bin/migrate + +WORKDIR /app + +# copy module files first so that they don't need to be downloaded again if no change +COPY go.* ./ +RUN go mod download +RUN go mod verify + +# copy source files and build the binary +COPY . . +RUN make build + + +FROM alpine:latest +RUN apk --no-cache add ca-certificates bash +RUN mkdir -p /var/log/app +WORKDIR /app/ +COPY --from=build /usr/local/bin/migrate /usr/local/bin +COPY --from=build /app/migrations ./migrations/ +COPY --from=build /app/server . +COPY --from=build /app/cmd/server/entrypoint.sh . +COPY --from=build /app/config/*.yml ./config/ +RUN ls -la +ENTRYPOINT ["./entrypoint.sh"] diff --git a/cmd/server/entrypoint.sh b/cmd/server/entrypoint.sh new file mode 100755 index 0000000..1119734 --- /dev/null +++ b/cmd/server/entrypoint.sh @@ -0,0 +1,19 @@ +#!/bin/bash -e + +exec > >(tee -a /var/log/app/entry.log|logger -t server -s 2>/dev/console) 2>&1 + +APP_ENV=${APP_ENV:-local} + +echo "[`date`] Running entrypoint script in the '${APP_ENV}' environment..." + +CONFIG_FILE=./config/${APP_ENV}.yml + +if [[ -z ${APP_DSN} ]]; then + export APP_DSN=`sed -n 's/^dsn:[[:space:]]*"\(.*\)"/\1/p' ${CONFIG_FILE}` +fi + +echo "[`date`] Running DB migrations..." +migrate -database "${APP_DSN}" -path ./migrations up + +echo "[`date`] Starting server..." +./server -config ${CONFIG_FILE} >> /var/log/app/server.log 2>&1 diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..b71d554 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,123 @@ +package main + +import ( + "context" + "database/sql" + "flag" + "fmt" + "github.com/go-ozzo/ozzo-dbx" + "github.com/go-ozzo/ozzo-routing/v2" + "github.com/go-ozzo/ozzo-routing/v2/content" + "github.com/go-ozzo/ozzo-routing/v2/cors" + _ "github.com/lib/pq" + "github.com/qiangxue/go-rest-api/internal/album" + "github.com/qiangxue/go-rest-api/internal/auth" + "github.com/qiangxue/go-rest-api/internal/config" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/internal/healthcheck" + "github.com/qiangxue/go-rest-api/pkg/accesslog" + "github.com/qiangxue/go-rest-api/pkg/dbcontext" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "os" + "time" +) + +// Version indicates the current version of the application. +var Version = "1.0.0" + +var flagConfig = flag.String("config", "./config/local.yml", "path to the config file") + +func main() { + flag.Parse() + // create root logger tagged with server version + logger := log.New().With(nil, "version", Version) + + // load application configurations + cfg, err := config.Load(*flagConfig, logger) + if err != nil { + logger.Errorf("failed to load application configuration: %s", err) + os.Exit(-1) + } + + // connect to the database + db, err := dbx.MustOpen("postgres", cfg.DSN) + if err != nil { + logger.Error(err) + os.Exit(-1) + } + db.QueryLogFunc = logDBQuery(logger) + db.ExecLogFunc = logDBExec(logger) + defer func() { + if err := db.Close(); err != nil { + logger.Error(err) + } + }() + + // build HTTP server + address := fmt.Sprintf(":%v", cfg.ServerPort) + hs := &http.Server{ + Addr: address, + Handler: buildHandler(logger, dbcontext.New(db), cfg), + } + + // start the HTTP server with graceful shutdown + go routing.GracefulShutdown(hs, 10*time.Second, logger.Infof) + logger.Infof("server %v is running at %v", Version, address) + if err := hs.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.Error(err) + os.Exit(-1) + } +} + +// buildHandler sets up the HTTP routing and builds an HTTP handler. +func buildHandler(logger log.Logger, db *dbcontext.DB, cfg *config.Config) http.Handler { + router := routing.New() + + router.Use( + accesslog.Handler(logger), + errors.Handler(logger), + content.TypeNegotiator(content.JSON), + cors.Handler(cors.AllowAll), + ) + + healthcheck.RegisterHandlers(router, Version) + + rg := router.Group("/v1") + + authHandler := auth.Handler(cfg.JWTVerificationKey) + + album.RegisterHandlers(rg.Group(""), + album.NewService(album.NewRepository(db, logger), logger), + authHandler, logger, + ) + + auth.RegisterHandlers(rg.Group(""), + auth.NewService(cfg.JWTSigningKey, cfg.JWTExpiration, logger), + logger, + ) + + return router +} + +// logDBQuery returns a logging function that can be used to log SQL queries. +func logDBQuery(logger log.Logger) dbx.QueryLogFunc { + return func(ctx context.Context, t time.Duration, sql string, rows *sql.Rows, err error) { + if err == nil { + logger.With(ctx, "duration", t.Milliseconds(), "sql", sql).Info("DB query successful") + } else { + logger.With(ctx, "sql", sql).Errorf("DB query error: %v", err) + } + } +} + +// logDBExec returns a logging function that can be used to log SQL executions. +func logDBExec(logger log.Logger) dbx.ExecLogFunc { + return func(ctx context.Context, t time.Duration, sql string, result sql.Result, err error) { + if err == nil { + logger.With(ctx, "duration", t.Milliseconds(), "sql", sql).Info("DB execution successful") + } else { + logger.With(ctx, "sql", sql).Errorf("DB execution error: %v", err) + } + } +} diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go new file mode 100644 index 0000000..2a974f6 --- /dev/null +++ b/cmd/server/main_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "context" + "fmt" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func Test_logDBQuery(t *testing.T) { + logger, entries := log.NewForTest() + f := logDBQuery(logger) + f(context.Background(), time.Millisecond*3, "sql", nil, nil) + if assert.Equal(t, 1, entries.Len()) { + assert.Equal(t, "DB query successful", entries.All()[0].Message) + } + entries.TakeAll() + + f(context.Background(), time.Millisecond*3, "sql", nil, fmt.Errorf("test")) + if assert.Equal(t, 1, entries.Len()) { + assert.Equal(t, "DB query error: test", entries.All()[0].Message) + } +} + +func Test_logDBExec(t *testing.T) { + logger, entries := log.NewForTest() + f := logDBExec(logger) + f(context.Background(), time.Millisecond*3, "sql", nil, nil) + if assert.Equal(t, 1, entries.Len()) { + assert.Equal(t, "DB execution successful", entries.All()[0].Message) + } + entries.TakeAll() + + f(context.Background(), time.Millisecond*3, "sql", nil, fmt.Errorf("test")) + if assert.Equal(t, 1, entries.Len()) { + assert.Equal(t, "DB execution error: test", entries.All()[0].Message) + } +} diff --git a/config/dev.yml b/config/dev.yml new file mode 100644 index 0000000..e69de29 diff --git a/config/local.yml b/config/local.yml new file mode 100644 index 0000000..3485bc9 --- /dev/null +++ b/config/local.yml @@ -0,0 +1,3 @@ +dsn: "postgres://127.0.0.1/go_restful?sslmode=disable&user=postgres&password=postgres" +jwt_signing_key: "LxsKJywDL5O5PvgODZhBH12KE6k2yL8E" +jwt_verification_key: "IuDQoh1QAIFwnKAydSntyJrTmOkct3wN" diff --git a/config/prod.yml b/config/prod.yml new file mode 100644 index 0000000..e69de29 diff --git a/config/qa.yml b/config/qa.yml new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..faabc75 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,29 @@ +version: '2.1' +services: + server: + image: server + build: + context: . + dockerfile: cmd/server/Dockerfile + volumes: + - /tmp/app:/var/log/app + ports: + - "8080:8080" + environment: + - APP_ENV=local + - APP_DSN=postgres://db/go_restful?sslmode=disable&user=postgres&password=postgres + depends_on: + db: + condition: service_healthy + db: + image: "postgres:alpine" + restart: always + environment: + POSTGRES_USER: "postgres" + POSTGRES_PASSWORD: "postgres" + POSTGRES_DB: "go_restful" + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 10s + timeout: 5s + retries: 5 diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..41b8a40 --- /dev/null +++ b/go.mod @@ -0,0 +1,19 @@ +module github.com/qiangxue/go-rest-api + +go 1.13 + +require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible + github.com/go-ozzo/ozzo-dbx v1.5.0 + github.com/go-ozzo/ozzo-routing/v2 v2.3.0 + github.com/go-ozzo/ozzo-validation/v4 v4.0.0 + github.com/google/uuid v1.1.1 + github.com/lib/pq v1.2.0 + github.com/qiangxue/go-env v1.0.0 + github.com/stretchr/testify v1.4.0 + go.uber.org/atomic v1.5.1 // indirect + go.uber.org/multierr v1.4.0 // indirect + go.uber.org/zap v1.13.0 + golang.org/x/tools v0.0.0-20191209225234-22774f7dae43 // indirect + gopkg.in/yaml.v2 v2.2.2 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..8250e4f --- /dev/null +++ b/go.sum @@ -0,0 +1,90 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= +github.com/go-ozzo/ozzo-dbx v1.5.0 h1:QPJOdFDKoJYlDLN7QczZ+uYUoIQD5gaiCvytCUMtSoE= +github.com/go-ozzo/ozzo-dbx v1.5.0/go.mod h1:ohIonWn3ed1mSYxvb5NTkaEjN4c52hbs8HI256FJhB8= +github.com/go-ozzo/ozzo-routing/v2 v2.3.0 h1:UtDziUJR20kj81xQU1IMDiDfUxcH1RNrU0rnaZCjtu4= +github.com/go-ozzo/ozzo-routing/v2 v2.3.0/go.mod h1:7gOQKWsVmMMEyAF2TnVrl1BtBv6XKY2UtmFJdC/krE8= +github.com/go-ozzo/ozzo-validation/v4 v4.0.0 h1:tZtJsQQJOBh6Cl5ni10Xqv53LzO+8A7iI0/X5ap9TnE= +github.com/go-ozzo/ozzo-validation/v4 v4.0.0/go.mod h1:cQmT+ki0c76Pk/pd0QohBsQ6BcqjeMM7Nkxi/kEdzAA= +github.com/go-sql-driver/mysql v1.4.1 h1:g24URVg0OFbNUTx9qqY1IRZ9D9z3iPyi5zKhQZpNwpA= +github.com/go-sql-driver/mysql v1.4.1/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2 h1:xisWqjiKEff2B0KfFYGpCqc3M3zdTz+OHQHRc09FeYk= +github.com/golang/gddo v0.0.0-20190904175337-72a348e765d2/go.mod h1:xEhNfoBDX1hzLm2Nf80qUvZ2sVwoMZ8d6IE2SrsQfh4= +github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.3.1 h1:Xye71clBPdm5HgqGwUkwhbynsUJZhDbS20FvLhQ2izg= +github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.2.0 h1:LXpIM/LZ5xGFhOpXAQUIMM1HdyqzVYM13zNdjCEEcA0= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/qiangxue/go-env v1.0.0 h1:WllJh3I59gq2Ekgf5mtSfhqtQcssVLfNKsZ2GgyoVsY= +github.com/qiangxue/go-env v1.0.0/go.mod h1:289F52HNQ7gxpmBgOqRVzV6onYxAdJrnjcylzJfY1NM= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.5.1 h1:rsqfU5vBkVknbhUGbAUwQKR2H4ItV8tjJ+6kJX4cxHM= +go.uber.org/atomic v1.5.1/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.4.0 h1:f3WCSC2KzAcBXGATIxAB1E2XuCpNU255wNKZ505qi3E= +go.uber.org/multierr v1.4.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee h1:0mgffUl7nfd+FpvXMVz4IDEaUSmT1ysygQC7qYo7sG4= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.13.0 h1:nR6NoDBgAf67s68NhaXbsojM+2gxp3S1hWkHDl27pVU= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f h1:J5lckAjkw6qYlOZNj90mLYNTEKDvWeuc1yieZ8qUzUE= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191205133340-d1f10d1c4e25/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191209225234-22774f7dae43 h1:NfPq5mgc5ArFgVLCpeS4z07IoxSAqVfV/gQ5vxdgaxI= +golang.org/x/tools v0.0.0-20191209225234-22774f7dae43/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM= +google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= +gopkg.in/asaskevich/govalidator.v9 v9.0.0-20180315120708-ccb8e960c48f h1:RVvpqSdNKxt6sENjmw0kdyyv8r18TdpmYTrvUUg2qkc= +gopkg.in/asaskevich/govalidator.v9 v9.0.0-20180315120708-ccb8e960c48f/go.mod h1:+MTrBL6wlsxv1uFXT6b9LWG7PJdrvUJEjl8tXOlk9OU= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/internal/album/api.go b/internal/album/api.go new file mode 100644 index 0000000..9ddb060 --- /dev/null +++ b/internal/album/api.go @@ -0,0 +1,91 @@ +package album + +import ( + "github.com/go-ozzo/ozzo-routing/v2" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/qiangxue/go-rest-api/pkg/pagination" + "net/http" +) + +// RegisterHandlers sets up the routing of the HTTP handlers. +func RegisterHandlers(r *routing.RouteGroup, service Service, authHandler routing.Handler, logger log.Logger) { + res := resource{service, logger} + + r.Get("/albums/", res.get) + r.Get("/albums", res.query) + + r.Use(authHandler) + + // the following endpoints require a valid JWT + r.Post("/albums", res.create) + r.Put("/albums/", res.update) + r.Delete("/albums/", res.delete) +} + +type resource struct { + service Service + logger log.Logger +} + +func (r resource) get(c *routing.Context) error { + album, err := r.service.Get(c.Request.Context(), c.Param("id")) + if err != nil { + return err + } + + return c.Write(album) +} + +func (r resource) query(c *routing.Context) error { + ctx := c.Request.Context() + count, err := r.service.Count(ctx) + if err != nil { + return err + } + pages := pagination.NewFromRequest(c.Request, count) + albums, err := r.service.Query(ctx, pages.Offset(), pages.Limit()) + if err != nil { + return err + } + pages.Items = albums + return c.Write(pages) +} + +func (r resource) create(c *routing.Context) error { + var input CreateAlbumRequest + if err := c.Read(&input); err != nil { + r.logger.With(c.Request.Context()).Info(err) + return errors.BadRequest("") + } + album, err := r.service.Create(c.Request.Context(), input) + if err != nil { + return err + } + + return c.WriteWithStatus(album, http.StatusCreated) +} + +func (r resource) update(c *routing.Context) error { + var input UpdateAlbumRequest + if err := c.Read(&input); err != nil { + r.logger.With(c.Request.Context()).Info(err) + return errors.BadRequest("") + } + + album, err := r.service.Update(c.Request.Context(), c.Param("id"), input) + if err != nil { + return err + } + + return c.Write(album) +} + +func (r resource) delete(c *routing.Context) error { + album, err := r.service.Delete(c.Request.Context(), c.Param("id")) + if err != nil { + return err + } + + return c.Write(album) +} diff --git a/internal/album/api_test.go b/internal/album/api_test.go new file mode 100644 index 0000000..3672749 --- /dev/null +++ b/internal/album/api_test.go @@ -0,0 +1,41 @@ +package album + +import ( + "github.com/qiangxue/go-rest-api/internal/auth" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/internal/test" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "testing" + "time" +) + +func TestAPI(t *testing.T) { + logger, _ := log.NewForTest() + router := test.MockRouter(logger) + repo := &mockRepository{items: []entity.Album{ + {"123", "album123", time.Now(), time.Now()}, + }} + RegisterHandlers(router.Group(""), NewService(repo, logger), auth.MockAuthHandler, logger) + header := auth.MockAuthHeader() + + tests := []test.APITestCase{ + {"get all", "GET", "/albums", "", nil, http.StatusOK, `*"total_count":1*`}, + {"get 123", "GET", "/albums/123", "", nil, http.StatusOK, `*album123*`}, + {"get unknown", "GET", "/albums/1234", "", nil, http.StatusNotFound, ""}, + {"create ok", "POST", "/albums", `{"name":"test"}`, header, http.StatusCreated, "*test*"}, + {"create ok count", "GET", "/albums", "", nil, http.StatusOK, `*"total_count":2*`}, + {"create auth error", "POST", "/albums", `{"name":"test"}`, nil, http.StatusUnauthorized, ""}, + {"create input error", "POST", "/albums", `"name":"test"}`, header, http.StatusBadRequest, ""}, + {"update ok", "PUT", "/albums/123", `{"name":"albumxyz"}`, header, http.StatusOK, "*albumxyz*"}, + {"update verify", "GET", "/albums/123", "", nil, http.StatusOK, `*albumxyz*`}, + {"update auth error", "PUT", "/albums/123", `{"name":"albumxyz"}`, nil, http.StatusUnauthorized, ""}, + {"update input error", "PUT", "/albums/123", `"name":"albumxyz"}`, header, http.StatusBadRequest, ""}, + {"delete ok", "DELETE", "/albums/123", ``, header, http.StatusOK, "*albumxyz*"}, + {"delete verify", "DELETE", "/albums/123", ``, header, http.StatusNotFound, ""}, + {"delete auth error", "DELETE", "/albums/123", ``, nil, http.StatusUnauthorized, ""}, + } + for _, tc := range tests { + test.Endpoint(t, router, tc) + } +} diff --git a/internal/album/repository.go b/internal/album/repository.go new file mode 100644 index 0000000..44a1e0a --- /dev/null +++ b/internal/album/repository.go @@ -0,0 +1,81 @@ +package album + +import ( + "context" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/pkg/dbcontext" + "github.com/qiangxue/go-rest-api/pkg/log" +) + +// Repository encapsulates the logic to access albums from the data source. +type Repository interface { + // Get returns the album with the specified album ID. + Get(ctx context.Context, id string) (entity.Album, error) + // Count returns the number of albums. + Count(ctx context.Context) (int, error) + // Query returns the list of albums with the given offset and limit. + Query(ctx context.Context, offset, limit int) ([]entity.Album, error) + // Create saves a new album in the storage. + Create(ctx context.Context, album entity.Album) error + // Update updates the album with given ID in the storage. + Update(ctx context.Context, album entity.Album) error + // Delete removes the album with given ID from the storage. + Delete(ctx context.Context, id string) error +} + +// repository persists albums in database +type repository struct { + db *dbcontext.DB + logger log.Logger +} + +// NewRepository creates a new album repository +func NewRepository(db *dbcontext.DB, logger log.Logger) Repository { + return repository{db, logger} +} + +// Get reads the album with the specified ID from the database. +func (r repository) Get(ctx context.Context, id string) (entity.Album, error) { + var album entity.Album + err := r.db.With(ctx).Select().Model(id, &album) + return album, err +} + +// Create saves a new album record in the database. +// It returns the ID of the newly inserted album record. +func (r repository) Create(ctx context.Context, album entity.Album) error { + return r.db.With(ctx).Model(&album).Insert() +} + +// Update saves the changes to an album in the database. +func (r repository) Update(ctx context.Context, album entity.Album) error { + return r.db.With(ctx).Model(&album).Update() +} + +// Delete deletes an album with the specified ID from the database. +func (r repository) Delete(ctx context.Context, id string) error { + album, err := r.Get(ctx, id) + if err != nil { + return err + } + return r.db.With(ctx).Model(&album).Delete() +} + +// Count returns the number of the album records in the database. +func (r repository) Count(ctx context.Context) (int, error) { + var count int + err := r.db.With(ctx).Select("COUNT(*)").From("album").Row(&count) + return count, err +} + +// Query retrieves the album records with the specified offset and limit from the database. +func (r repository) Query(ctx context.Context, offset, limit int) ([]entity.Album, error) { + var albums []entity.Album + err := r.db.With(ctx). + Select(). + OrderBy("id"). + Offset(int64(offset)). + Limit(int64(limit)). + All(&albums) + return albums, err +} diff --git a/internal/album/repository_test.go b/internal/album/repository_test.go new file mode 100644 index 0000000..635b9d8 --- /dev/null +++ b/internal/album/repository_test.go @@ -0,0 +1,67 @@ +package album + +import ( + "context" + "database/sql" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/internal/test" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestRepository(t *testing.T) { + logger, _ := log.NewForTest() + db := test.DB(t) + test.ResetTables(t, db, "album") + repo := NewRepository(db, logger) + + ctx := context.Background() + + // initial count + count, err := repo.Count(ctx) + assert.Nil(t, err) + + // create + err = repo.Create(ctx, entity.Album{ + ID: "test1", + Name: "album1", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + assert.Nil(t, err) + count2, _ := repo.Count(ctx) + assert.Equal(t, 1, count2-count) + + // get + album, err := repo.Get(ctx, "test1") + assert.Nil(t, err) + assert.Equal(t, "album1", album.Name) + _, err = repo.Get(ctx, "test0") + assert.Equal(t, sql.ErrNoRows, err) + + // update + err = repo.Update(ctx, entity.Album{ + ID: "test1", + Name: "album1 updated", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }) + assert.Nil(t, err) + album, _ = repo.Get(ctx, "test1") + assert.Equal(t, "album1 updated", album.Name) + + // query + albums, err := repo.Query(ctx, 0, count2) + assert.Nil(t, err) + assert.Equal(t, count2, len(albums)) + + // delete + err = repo.Delete(ctx, "test1") + assert.Nil(t, err) + _, err = repo.Get(ctx, "test1") + assert.Equal(t, sql.ErrNoRows, err) + err = repo.Delete(ctx, "test1") + assert.Equal(t, sql.ErrNoRows, err) +} diff --git a/internal/album/service.go b/internal/album/service.go new file mode 100644 index 0000000..6717e57 --- /dev/null +++ b/internal/album/service.go @@ -0,0 +1,135 @@ +package album + +import ( + "context" + validation "github.com/go-ozzo/ozzo-validation/v4" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/pkg/log" + "time" +) + +// Service encapsulates usecase logic for albums. +type Service interface { + Get(ctx context.Context, id string) (Album, error) + Query(ctx context.Context, offset, limit int) ([]Album, error) + Count(ctx context.Context) (int, error) + Create(ctx context.Context, input CreateAlbumRequest) (Album, error) + Update(ctx context.Context, id string, input UpdateAlbumRequest) (Album, error) + Delete(ctx context.Context, id string) (Album, error) +} + +// Album represents the data about an album. +type Album struct { + entity.Album +} + +// CreateAlbumRequest represents an album creation request. +type CreateAlbumRequest struct { + Name string `json:"name"` +} + +// Validate validates the CreateAlbumRequest fields. +func (m CreateAlbumRequest) Validate() error { + return validation.ValidateStruct(&m, + validation.Field(&m.Name, validation.Required, validation.Length(0, 128)), + ) +} + +// UpdateAlbumRequest represents an album update request. +type UpdateAlbumRequest struct { + Name string `json:"name"` +} + +// Validate validates the CreateAlbumRequest fields. +func (m UpdateAlbumRequest) Validate() error { + return validation.ValidateStruct(&m, + validation.Field(&m.Name, validation.Required, validation.Length(0, 128)), + ) +} + +type service struct { + repo Repository + logger log.Logger +} + +// NewService creates a new album service. +func NewService(repo Repository, logger log.Logger) Service { + return service{repo, logger} +} + +// Get returns the album with the specified the album ID. +func (s service) Get(ctx context.Context, id string) (Album, error) { + album, err := s.repo.Get(ctx, id) + if err != nil { + return Album{}, err + } + return Album{album}, nil +} + +// Create creates a new album. +func (s service) Create(ctx context.Context, req CreateAlbumRequest) (Album, error) { + if err := req.Validate(); err != nil { + return Album{}, err + } + id := entity.GenerateID() + now := time.Now() + err := s.repo.Create(ctx, entity.Album{ + ID: id, + Name: req.Name, + CreatedAt: now, + UpdatedAt: now, + }) + if err != nil { + return Album{}, err + } + return s.Get(ctx, id) +} + +// Update updates the album with the specified ID. +func (s service) Update(ctx context.Context, id string, req UpdateAlbumRequest) (Album, error) { + if err := req.Validate(); err != nil { + return Album{}, err + } + + album, err := s.Get(ctx, id) + if err != nil { + return album, err + } + album.Name = req.Name + album.UpdatedAt = time.Now() + + if err := s.repo.Update(ctx, album.Album); err != nil { + return album, err + } + return album, nil +} + +// Delete deletes the album with the specified ID. +func (s service) Delete(ctx context.Context, id string) (Album, error) { + album, err := s.Get(ctx, id) + if err != nil { + return Album{}, err + } + if err = s.repo.Delete(ctx, id); err != nil { + return Album{}, err + } + return album, nil +} + +// Count returns the number of albums. +func (s service) Count(ctx context.Context) (int, error) { + return s.repo.Count(ctx) +} + +// Query returns the albums with the specified offset and limit. +func (s service) Query(ctx context.Context, offset, limit int) ([]Album, error) { + items, err := s.repo.Query(ctx, offset, limit) + if err != nil { + return nil, err + } + result := []Album{} + for _, item := range items { + result = append(result, Album{item}) + } + return result, nil +} diff --git a/internal/album/service_test.go b/internal/album/service_test.go new file mode 100644 index 0000000..a557f9f --- /dev/null +++ b/internal/album/service_test.go @@ -0,0 +1,178 @@ +package album + +import ( + "context" + "database/sql" + "errors" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "testing" +) + +var errCRUD = errors.New("error crud") + +func TestCreateAlbumRequest_Validate(t *testing.T) { + tests := []struct { + name string + model CreateAlbumRequest + wantError bool + }{ + {"success", CreateAlbumRequest{Name: "test"}, false}, + {"required", CreateAlbumRequest{Name: ""}, true}, + {"too long", CreateAlbumRequest{Name: "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.model.Validate() + assert.Equal(t, tt.wantError, err != nil) + }) + } +} + +func TestUpdateAlbumRequest_Validate(t *testing.T) { + tests := []struct { + name string + model UpdateAlbumRequest + wantError bool + }{ + {"success", UpdateAlbumRequest{Name: "test"}, false}, + {"required", UpdateAlbumRequest{Name: ""}, true}, + {"too long", UpdateAlbumRequest{Name: "1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890"}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.model.Validate() + assert.Equal(t, tt.wantError, err != nil) + }) + } +} + +func Test_service_CRUD(t *testing.T) { + logger, _ := log.NewForTest() + s := NewService(&mockRepository{}, logger) + + ctx := context.Background() + + // initial count + count, _ := s.Count(ctx) + assert.Equal(t, 0, count) + + // successful creation + album, err := s.Create(ctx, CreateAlbumRequest{Name: "test"}) + assert.Nil(t, err) + assert.NotEmpty(t, album.ID) + id := album.ID + assert.Equal(t, "test", album.Name) + assert.NotEmpty(t, album.CreatedAt) + assert.NotEmpty(t, album.UpdatedAt) + count, _ = s.Count(ctx) + assert.Equal(t, 1, count) + + // validation error in creation + _, err = s.Create(ctx, CreateAlbumRequest{Name: ""}) + assert.NotNil(t, err) + count, _ = s.Count(ctx) + assert.Equal(t, 1, count) + + // unexpected error in creation + _, err = s.Create(ctx, CreateAlbumRequest{Name: "error"}) + assert.Equal(t, errCRUD, err) + count, _ = s.Count(ctx) + assert.Equal(t, 1, count) + + _, _ = s.Create(ctx, CreateAlbumRequest{Name: "test2"}) + + // update + album, err = s.Update(ctx, id, UpdateAlbumRequest{Name: "test updated"}) + assert.Nil(t, err) + assert.Equal(t, "test updated", album.Name) + _, err = s.Update(ctx, "none", UpdateAlbumRequest{Name: "test updated"}) + assert.NotNil(t, err) + + // validation error in update + _, err = s.Update(ctx, id, UpdateAlbumRequest{Name: ""}) + assert.NotNil(t, err) + count, _ = s.Count(ctx) + assert.Equal(t, 2, count) + + // unexpected error in update + _, err = s.Update(ctx, id, UpdateAlbumRequest{Name: "error"}) + assert.Equal(t, errCRUD, err) + count, _ = s.Count(ctx) + assert.Equal(t, 2, count) + + // get + _, err = s.Get(ctx, "none") + assert.NotNil(t, err) + album, err = s.Get(ctx, id) + assert.Nil(t, err) + assert.Equal(t, "test updated", album.Name) + assert.Equal(t, id, album.ID) + + // query + albums, _ := s.Query(ctx, 0, 0) + assert.Equal(t, 2, len(albums)) + + // delete + _, err = s.Delete(ctx, "none") + assert.NotNil(t, err) + album, err = s.Delete(ctx, id) + assert.Nil(t, err) + assert.Equal(t, id, album.ID) + count, _ = s.Count(ctx) + assert.Equal(t, 1, count) +} + +type mockRepository struct { + items []entity.Album +} + +func (m mockRepository) Get(ctx context.Context, id string) (entity.Album, error) { + for _, item := range m.items { + if item.ID == id { + return item, nil + } + } + return entity.Album{}, sql.ErrNoRows +} + +func (m mockRepository) Count(ctx context.Context) (int, error) { + return len(m.items), nil +} + +func (m mockRepository) Query(ctx context.Context, offset, limit int) ([]entity.Album, error) { + return m.items, nil +} + +func (m *mockRepository) Create(ctx context.Context, album entity.Album) error { + if album.Name == "error" { + return errCRUD + } + m.items = append(m.items, album) + return nil +} + +func (m *mockRepository) Update(ctx context.Context, album entity.Album) error { + if album.Name == "error" { + return errCRUD + } + for i, item := range m.items { + if item.ID == album.ID { + m.items[i] = album + break + } + } + return nil +} + +func (m *mockRepository) Delete(ctx context.Context, id string) error { + for i, item := range m.items { + if item.ID == id { + m.items[i] = m.items[len(m.items)-1] + m.items = m.items[:len(m.items)-1] + break + } + } + return nil +} diff --git a/internal/auth/api.go b/internal/auth/api.go new file mode 100644 index 0000000..7d201d4 --- /dev/null +++ b/internal/auth/api.go @@ -0,0 +1,35 @@ +package auth + +import ( + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/pkg/log" +) + +// RegisterHandlers registers handlers for different HTTP requests. +func RegisterHandlers(rg *routing.RouteGroup, service Service, logger log.Logger) { + rg.Post("/login", login(service, logger)) +} + +// login returns a handler that handles user login request. +func login(service Service, logger log.Logger) routing.Handler { + return func(c *routing.Context) error { + var req struct { + Username string `json:"username"` + Password string `json:"password"` + } + + if err := c.Read(&req); err != nil { + logger.With(c.Request.Context()).Errorf("invalid request: %v", err) + return errors.BadRequest("") + } + + token, err := service.Login(c.Request.Context(), req.Username, req.Password) + if err != nil { + return err + } + return c.Write(struct { + Token string `json:"token"` + }{token}) + } +} diff --git a/internal/auth/api_test.go b/internal/auth/api_test.go new file mode 100644 index 0000000..4465966 --- /dev/null +++ b/internal/auth/api_test.go @@ -0,0 +1,34 @@ +package auth + +import ( + "context" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/internal/test" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "testing" +) + +type mockService struct{} + +func (m mockService) Login(ctx context.Context, username, password string) (string, error) { + if username == "test" && password == "pass" { + return "token-100", nil + } + return "", errors.Unauthorized("") +} + +func TestAPI(t *testing.T) { + logger, _ := log.NewForTest() + router := test.MockRouter(logger) + RegisterHandlers(router.Group(""), mockService{}, logger) + + tests := []test.APITestCase{ + {"success", "POST", "/login", `{"username":"test","password":"pass"}`, nil, http.StatusOK, `{"token":"token-100"}`}, + {"bad credential", "POST", "/login", `{"username":"test","password":"wrong pass"}`, nil, http.StatusUnauthorized, ""}, + {"bad json", "POST", "/login", `"username":"test","password":"wrong pass"}`, nil, http.StatusBadRequest, ""}, + } + for _, tc := range tests { + test.Endpoint(t, router, tc) + } +} diff --git a/internal/auth/middleware.go b/internal/auth/middleware.go new file mode 100644 index 0000000..df48004 --- /dev/null +++ b/internal/auth/middleware.go @@ -0,0 +1,67 @@ +package auth + +import ( + "context" + "github.com/dgrijalva/jwt-go" + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/go-ozzo/ozzo-routing/v2/auth" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/internal/errors" + "net/http" +) + +// Handler returns a JWT-based authentication middleware. +func Handler(verificationKey string) routing.Handler { + return auth.JWT(verificationKey, auth.JWTOptions{TokenHandler: handleToken}) +} + +// handleToken stores the user identity in the request context so that it can be accessed elsewhere. +func handleToken(c *routing.Context, token *jwt.Token) error { + ctx := WithUser( + c.Request.Context(), + token.Claims.(jwt.MapClaims)["id"].(string), + token.Claims.(jwt.MapClaims)["name"].(string), + ) + c.Request = c.Request.WithContext(ctx) + return nil +} + +type contextKey int + +const ( + userKey contextKey = iota +) + +// WithUser returns a context that contains the user identity from the given JWT. +func WithUser(ctx context.Context, id, name string) context.Context { + return context.WithValue(ctx, userKey, entity.User{ID: id, Name: name}) +} + +// CurrentUser returns the user identity from the given context. +// Nil is returned if no user identity is found in the context. +func CurrentUser(ctx context.Context) Identity { + if user, ok := ctx.Value(userKey).(entity.User); ok { + return user + } + return nil +} + +// MockAuthHandler creates a mock authentication middleware for testing purpose. +// If the request contains an Authorization header whose value is "TEST", then +// it considers the user is authenticated as "Tester" whose ID is "100". +// It fails the authentication otherwise. +func MockAuthHandler(c *routing.Context) error { + if c.Request.Header.Get("Authorization") != "TEST" { + return errors.Unauthorized("") + } + ctx := WithUser(c.Request.Context(), "100", "Tester") + c.Request = c.Request.WithContext(ctx) + return nil +} + +// MockAuthHeader returns an HTTP header that can pass the authentication check by MockAuthHandler. +func MockAuthHeader() http.Header { + header := http.Header{} + header.Add("Authorization", "TEST") + return header +} diff --git a/internal/auth/middleware_test.go b/internal/auth/middleware_test.go new file mode 100644 index 0000000..8c243fd --- /dev/null +++ b/internal/auth/middleware_test.go @@ -0,0 +1,54 @@ +package auth + +import ( + "context" + "github.com/dgrijalva/jwt-go" + "github.com/qiangxue/go-rest-api/internal/test" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestCurrentUser(t *testing.T) { + ctx := context.Background() + assert.Nil(t, CurrentUser(ctx)) + ctx = WithUser(ctx, "100", "test") + identity := CurrentUser(ctx) + if assert.NotNil(t, identity) { + assert.Equal(t, "100", identity.GetID()) + assert.Equal(t, "test", identity.GetName()) + } +} + +func TestHandler(t *testing.T) { + assert.NotNil(t, Handler("test")) +} + +func Test_handleToken(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + ctx, _ := test.MockRoutingContext(req) + assert.Nil(t, CurrentUser(ctx.Request.Context())) + + err := handleToken(ctx, &jwt.Token{ + Claims: jwt.MapClaims{ + "id": "100", + "name": "test", + }, + }) + assert.Nil(t, err) + identity := CurrentUser(ctx.Request.Context()) + if assert.NotNil(t, identity) { + assert.Equal(t, "100", identity.GetID()) + assert.Equal(t, "test", identity.GetName()) + } +} + +func TestMocks(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", nil) + ctx, _ := test.MockRoutingContext(req) + assert.NotNil(t, MockAuthHandler(ctx)) + req.Header = MockAuthHeader() + ctx, _ = test.MockRoutingContext(req) + assert.Nil(t, MockAuthHandler(ctx)) + assert.NotNil(t, CurrentUser(ctx.Request.Context())) +} diff --git a/internal/auth/service.go b/internal/auth/service.go new file mode 100644 index 0000000..786e48f --- /dev/null +++ b/internal/auth/service.go @@ -0,0 +1,69 @@ +package auth + +import ( + "context" + "github.com/dgrijalva/jwt-go" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/pkg/log" + "time" +) + +// Service encapsulates the authentication logic. +type Service interface { + // authenticate authenticates a user using username and password. + // It returns a JWT token if authentication succeeds. Otherwise, an error is returned. + Login(ctx context.Context, username, password string) (string, error) +} + +// Identity represents an authenticated user identity. +type Identity interface { + // GetID returns the user ID. + GetID() string + // GetName returns the user name. + GetName() string +} + +type service struct { + signingKey string + tokenExpiration int + logger log.Logger +} + +// NewService creates a new authentication service. +func NewService(signingKey string, tokenExpiration int, logger log.Logger) Service { + return service{signingKey, tokenExpiration, logger} +} + +// Login authenticates a user and generates a JWT token if authentication succeeds. +// Otherwise, an error is returned. +func (s service) Login(ctx context.Context, username, password string) (string, error) { + if identity := s.authenticate(ctx, username, password); identity != nil { + return s.generateJWT(identity) + } + return "", errors.Unauthorized("") +} + +// authenticate authenticates a user using username and password. +// If username and password are correct, an identity is returned. Otherwise, nil is returned. +func (s service) authenticate(ctx context.Context, username, password string) Identity { + logger := s.logger.With(ctx, "user", username) + + // TODO: the following authentication logic is only for demo purpose + if username == "demo" && password == "pass" { + logger.Infof("authentication successful") + return entity.User{ID: "100", Name: "demo"} + } + + logger.Infof("authentication failed") + return nil +} + +// generateJWT generates a JWT that encodes an identity. +func (s service) generateJWT(identity Identity) (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "id": identity.GetID(), + "name": identity.GetName(), + "exp": time.Now().Add(time.Duration(s.tokenExpiration) * time.Hour).Unix(), + }).SignedString([]byte(s.signingKey)) +} diff --git a/internal/auth/service_test.go b/internal/auth/service_test.go new file mode 100644 index 0000000..e39f0bf --- /dev/null +++ b/internal/auth/service_test.go @@ -0,0 +1,39 @@ +package auth + +import ( + "context" + "github.com/qiangxue/go-rest-api/internal/entity" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "testing" +) + +func Test_service_Authenticate(t *testing.T) { + logger, _ := log.NewForTest() + s := NewService("test", 100, logger) + _, err := s.Login(context.Background(), "unknown", "bad") + assert.Equal(t, errors.Unauthorized(""), err) + token, err := s.Login(context.Background(), "demo", "pass") + assert.Nil(t, err) + assert.NotEmpty(t, token) +} + +func Test_service_authenticate(t *testing.T) { + logger, _ := log.NewForTest() + s := service{"test", 100, logger} + assert.Nil(t, s.authenticate(context.Background(), "unknown", "bad")) + assert.NotNil(t, s.authenticate(context.Background(), "demo", "pass")) +} + +func Test_service_GenerateJWT(t *testing.T) { + logger, _ := log.NewForTest() + s := service{"test", 100, logger} + token, err := s.generateJWT(entity.User{ + ID: "100", + Name: "demo", + }) + if assert.Nil(t, err) { + assert.NotEmpty(t, token) + } +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..6085209 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,67 @@ +package config + +import ( + "github.com/go-ozzo/ozzo-validation/v4" + "github.com/qiangxue/go-env" + "github.com/qiangxue/go-rest-api/pkg/log" + "gopkg.in/yaml.v2" + "io/ioutil" +) + +const ( + defaultServerPort = 8080 + defaultJWTExpirationHours = 72 +) + +// Config represents an application configuration. +type Config struct { + // the server port. Defaults to 8080 + ServerPort int `yaml:"server_port" env:"SERVER_PORT"` + // the data source name (DSN) for connecting to the database. required. + DSN string `yaml:"dsn" env:"DSN,secret"` + // JWT signing key. required. + JWTSigningKey string `yaml:"jwt_signing_key" env:"JWT_SIGNING_KEY,secret"` + // JWT verification key. required. + JWTVerificationKey string `yaml:"jwt_verification_key" env:"JWT_VERIFICATION_KEY,secret"` + // JWT expiration in hours. Defaults to 72 hours (3 days) + JWTExpiration int `yaml:"jwt_expiration" env:"JWT_EXPIRATION"` +} + +// Validate validates the application configuration. +func (c Config) Validate() error { + return validation.ValidateStruct(&c, + validation.Field(&c.DSN, validation.Required), + validation.Field(&c.JWTSigningKey, validation.Required), + validation.Field(&c.JWTVerificationKey, validation.Required), + ) +} + +// Load returns an application configuration which is populated from the given configuration file and environment variables. +func Load(file string, logger log.Logger) (*Config, error) { + // default config + c := Config{ + ServerPort: defaultServerPort, + JWTExpiration: defaultJWTExpirationHours, + } + + // load from YAML config file + bytes, err := ioutil.ReadFile(file) + if err != nil { + return nil, err + } + if err = yaml.Unmarshal(bytes, &c); err != nil { + return nil, err + } + + // load from environment variables prefixed with "APP_" + if err = env.New("APP_", logger.Infof).Load(&c); err != nil { + return nil, err + } + + // validation + if err = c.Validate(); err != nil { + return nil, err + } + + return &c, err +} diff --git a/internal/entity/album.go b/internal/entity/album.go new file mode 100644 index 0000000..ca20d14 --- /dev/null +++ b/internal/entity/album.go @@ -0,0 +1,13 @@ +package entity + +import ( + "time" +) + +// Album represents an album record. +type Album struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/internal/entity/id.go b/internal/entity/id.go new file mode 100644 index 0000000..449611c --- /dev/null +++ b/internal/entity/id.go @@ -0,0 +1,8 @@ +package entity + +import "github.com/google/uuid" + +// GenerateID generates a unique ID that can be used as an identifier for an entity. +func GenerateID() string { + return uuid.New().String() +} diff --git a/internal/entity/user.go b/internal/entity/user.go new file mode 100644 index 0000000..1b918c8 --- /dev/null +++ b/internal/entity/user.go @@ -0,0 +1,17 @@ +package entity + +// User represents a user. +type User struct { + ID string + Name string +} + +// GetID returns the user ID. +func (u User) GetID() string { + return u.ID +} + +// GetName returns the user name. +func (u User) GetName() string { + return u.Name +} diff --git a/internal/errors/middleware.go b/internal/errors/middleware.go new file mode 100644 index 0000000..e8f800a --- /dev/null +++ b/internal/errors/middleware.go @@ -0,0 +1,68 @@ +package errors + +import ( + "database/sql" + "errors" + "fmt" + routing "github.com/go-ozzo/ozzo-routing/v2" + validation "github.com/go-ozzo/ozzo-validation/v4" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "runtime/debug" +) + +// Handler creates a middleware that handles panics and errors encountered during HTTP request processing. +func Handler(logger log.Logger) routing.Handler { + return func(c *routing.Context) (err error) { + defer func() { + l := logger.With(c.Request.Context()) + if e := recover(); e != nil { + var ok bool + if err, ok = e.(error); !ok { + err = fmt.Errorf("%v", e) + } + + l.Errorf("recovered from panic (%v): %s", err, debug.Stack()) + } + + if err != nil { + res := buildErrorResponse(err) + if res.StatusCode() == http.StatusInternalServerError { + l.Errorf("encountered internal server error: %v", err) + } + c.Response.WriteHeader(res.StatusCode()) + if err = c.Write(res); err != nil { + l.Errorf("failed writing error response: %v", err) + } + c.Abort() // skip any pending handlers since an error has occurred + err = nil // return nil because the error is already handled + } + }() + return c.Next() + } +} + +// buildErrorResponse builds an error response from an error. +func buildErrorResponse(err error) ErrorResponse { + switch err.(type) { + case ErrorResponse: + return err.(ErrorResponse) + case validation.Errors: + return InvalidInput(err.(validation.Errors)) + case routing.HTTPError: + switch err.(routing.HTTPError).StatusCode() { + case http.StatusNotFound: + return NotFound("") + default: + return ErrorResponse{ + Status: err.(routing.HTTPError).StatusCode(), + Message: err.Error(), + } + } + } + + if errors.Is(err, sql.ErrNoRows) { + return NotFound("") + } + return InternalServerError("") +} diff --git a/internal/errors/middleware_test.go b/internal/errors/middleware_test.go new file mode 100644 index 0000000..f26c0a9 --- /dev/null +++ b/internal/errors/middleware_test.go @@ -0,0 +1,93 @@ +package errors + +import ( + "database/sql" + "fmt" + routing "github.com/go-ozzo/ozzo-routing/v2" + validation "github.com/go-ozzo/ozzo-validation/v4" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHandler(t *testing.T) { + t.Run("normal processing", func(t *testing.T) { + logger, entries := log.NewForTest() + handler := Handler(logger) + ctx, res := buildContext(handler, handlerOK) + assert.Nil(t, ctx.Next()) + assert.Zero(t, entries.Len()) + assert.Equal(t, http.StatusOK, res.Code) + }) + + t.Run("error processing", func(t *testing.T) { + logger, entries := log.NewForTest() + handler := Handler(logger) + ctx, res := buildContext(handler, handlerError) + assert.Nil(t, ctx.Next()) + assert.Equal(t, 1, entries.Len()) + assert.Equal(t, http.StatusInternalServerError, res.Code) + }) + + t.Run("HTTP error processing", func(t *testing.T) { + logger, entries := log.NewForTest() + handler := Handler(logger) + ctx, res := buildContext(handler, handlerHTTPError) + assert.Nil(t, ctx.Next()) + assert.Equal(t, 0, entries.Len()) + assert.Equal(t, http.StatusNotFound, res.Code) + }) + + t.Run("panic processing", func(t *testing.T) { + logger, entries := log.NewForTest() + handler := Handler(logger) + ctx, res := buildContext(handler, handlerPanic) + assert.Nil(t, ctx.Next()) + assert.Equal(t, 2, entries.Len()) + assert.Equal(t, http.StatusInternalServerError, res.Code) + }) +} + +func Test_buildErrorResponse(t *testing.T) { + res := NotFound("") + assert.Equal(t, res, buildErrorResponse(res)) + + res = buildErrorResponse(routing.NewHTTPError(http.StatusNotFound)) + assert.Equal(t, http.StatusNotFound, res.Status) + + res = buildErrorResponse(validation.Errors{}) + assert.Equal(t, http.StatusBadRequest, res.Status) + + res = buildErrorResponse(routing.NewHTTPError(http.StatusForbidden)) + assert.Equal(t, http.StatusForbidden, res.Status) + + res = buildErrorResponse(sql.ErrNoRows) + assert.Equal(t, http.StatusNotFound, res.Status) + + res = buildErrorResponse(fmt.Errorf("test")) + assert.Equal(t, http.StatusInternalServerError, res.Status) +} + +func buildContext(handlers ...routing.Handler) (*routing.Context, *httptest.ResponseRecorder) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) + return routing.NewContext(res, req, handlers...), res +} + +func handlerOK(c *routing.Context) error { + return c.Write("test") +} + +func handlerError(c *routing.Context) error { + return fmt.Errorf("abc") +} + +func handlerHTTPError(c *routing.Context) error { + return NotFound("") +} + +func handlerPanic(c *routing.Context) error { + panic("xyz") +} diff --git a/internal/errors/response.go b/internal/errors/response.go new file mode 100644 index 0000000..0dc06bc --- /dev/null +++ b/internal/errors/response.go @@ -0,0 +1,106 @@ +package errors + +import ( + validation "github.com/go-ozzo/ozzo-validation/v4" + "net/http" + "sort" +) + +// ErrorResponse is the response that represents an error. +type ErrorResponse struct { + Status int `json:"status"` + Message string `json:"message"` + Details interface{} `json:"details,omitempty"` +} + +// Error is required by the error interface. +func (e ErrorResponse) Error() string { + return e.Message +} + +// StatusCode is required by routing.HTTPError interface. +func (e ErrorResponse) StatusCode() int { + return e.Status +} + +// InternalServerError creates a new error response representing an internal server error (HTTP 500) +func InternalServerError(msg string) ErrorResponse { + if msg == "" { + msg = "We encountered an error while processing your request." + } + return ErrorResponse{ + Status: http.StatusInternalServerError, + Message: msg, + } +} + +// NotFound creates a new error response representing a resource-not-found error (HTTP 404) +func NotFound(msg string) ErrorResponse { + if msg == "" { + msg = "The requested resource was not found." + } + return ErrorResponse{ + Status: http.StatusNotFound, + Message: msg, + } +} + +// Unauthorized creates a new error response representing an authentication/authorization failure (HTTP 401) +func Unauthorized(msg string) ErrorResponse { + if msg == "" { + msg = "You are not authenticated to perform the requested action." + } + return ErrorResponse{ + Status: http.StatusUnauthorized, + Message: msg, + } +} + +// Forbidden creates a new error response representing an authorization failure (HTTP 403) +func Forbidden(msg string) ErrorResponse { + if msg == "" { + msg = "You are not authorized to perform the requested action." + } + return ErrorResponse{ + Status: http.StatusForbidden, + Message: msg, + } +} + +// BadRequest creates a new error response representing a bad request (HTTP 400) +func BadRequest(msg string) ErrorResponse { + if msg == "" { + msg = "Your request is in a bad format." + } + return ErrorResponse{ + Status: http.StatusBadRequest, + Message: msg, + } +} + +type invalidField struct { + Field string `json:"field"` + Error string `json:"error"` +} + +// InvalidInput creates a new error response representing a data validation error (HTTP 400). +func InvalidInput(errs validation.Errors) ErrorResponse { + var details []invalidField + var fields []string + for field := range errs { + fields = append(fields, field) + } + sort.Strings(fields) + for _, field := range fields { + details = append(details, invalidField{ + Field: field, + Error: errs[field].Error(), + }) + } + + return ErrorResponse{ + Status: http.StatusBadRequest, + Message: "There is some problem with the data you submitted.", + Details: details, + } +} diff --git a/internal/errors/response_test.go b/internal/errors/response_test.go new file mode 100644 index 0000000..6d506ab --- /dev/null +++ b/internal/errors/response_test.go @@ -0,0 +1,72 @@ +package errors + +import ( + "fmt" + validation "github.com/go-ozzo/ozzo-validation/v4" + "github.com/stretchr/testify/assert" + "net/http" + "testing" +) + +func TestErrorResponse_Error(t *testing.T) { + e := ErrorResponse{ + Message: "abc", + } + assert.Equal(t, "abc", e.Error()) +} + +func TestErrorResponse_StatusCode(t *testing.T) { + e := ErrorResponse{ + Status: 400, + } + assert.Equal(t, 400, e.StatusCode()) +} + +func TestInternalServerError(t *testing.T) { + res := InternalServerError("test") + assert.Equal(t, http.StatusInternalServerError, res.StatusCode()) + assert.Equal(t, "test", res.Error()) + res = InternalServerError("") + assert.NotEmpty(t, res.Error()) +} + +func TestNotFound(t *testing.T) { + res := NotFound("test") + assert.Equal(t, http.StatusNotFound, res.StatusCode()) + assert.Equal(t, "test", res.Error()) + res = NotFound("") + assert.NotEmpty(t, res.Error()) +} + +func TestUnauthorized(t *testing.T) { + res := Unauthorized("test") + assert.Equal(t, http.StatusUnauthorized, res.StatusCode()) + assert.Equal(t, "test", res.Error()) + res = Unauthorized("") + assert.NotEmpty(t, res.Error()) +} + +func TestForbidden(t *testing.T) { + res := Forbidden("test") + assert.Equal(t, http.StatusForbidden, res.StatusCode()) + assert.Equal(t, "test", res.Error()) + res = Forbidden("") + assert.NotEmpty(t, res.Error()) +} + +func TestBadRequest(t *testing.T) { + res := BadRequest("test") + assert.Equal(t, http.StatusBadRequest, res.StatusCode()) + assert.Equal(t, "test", res.Error()) + res = BadRequest("") + assert.NotEmpty(t, res.Error()) +} + +func TestInvalidInput(t *testing.T) { + err := InvalidInput(validation.Errors{ + "xyz": fmt.Errorf("2"), + "abc": fmt.Errorf("1"), + }) + assert.Equal(t, http.StatusBadRequest, err.Status) + assert.Equal(t, []invalidField{{"abc", "1"}, {"xyz", "2"}}, err.Details) +} diff --git a/internal/healthcheck/api.go b/internal/healthcheck/api.go new file mode 100644 index 0000000..ebc39da --- /dev/null +++ b/internal/healthcheck/api.go @@ -0,0 +1,15 @@ +package healthcheck + +import routing "github.com/go-ozzo/ozzo-routing/v2" + +// RegisterHandlers registers the handlers that perform healthchecks. +func RegisterHandlers(r *routing.Router, version string) { + r.To("GET,HEAD", "/healthcheck", healthcheck(version)) +} + +// healthcheck responds to a healthcheck request. +func healthcheck(version string) routing.Handler { + return func(c *routing.Context) error { + return c.Write("OK " + version) + } +} diff --git a/internal/healthcheck/api_test.go b/internal/healthcheck/api_test.go new file mode 100644 index 0000000..52fd36c --- /dev/null +++ b/internal/healthcheck/api_test.go @@ -0,0 +1,17 @@ +package healthcheck + +import ( + "github.com/qiangxue/go-rest-api/internal/test" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "testing" +) + +func TestAPI(t *testing.T) { + logger, _ := log.NewForTest() + router := test.MockRouter(logger) + RegisterHandlers(router, "0.9.0") + test.Endpoint(t, router, test.APITestCase{ + "ok", "GET", "/healthcheck", "", nil, http.StatusOK, `"OK 0.9.0"`, + }) +} diff --git a/internal/test/api.go b/internal/test/api.go new file mode 100644 index 0000000..227386a --- /dev/null +++ b/internal/test/api.go @@ -0,0 +1,45 @@ +package test + +import ( + "bytes" + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// APITestCase represents the data needed to describe an API test case. +type APITestCase struct { + Name string + Method, URL string + Body string + Header http.Header + WantStatus int + WantResponse string +} + +// Endpoint tests an HTTP endpoint using the given APITestCase spec. +func Endpoint(t *testing.T, router *routing.Router, tc APITestCase) { + t.Run(tc.Name, func(t *testing.T) { + req, _ := http.NewRequest(tc.Method, tc.URL, bytes.NewBufferString(tc.Body)) + if tc.Header != nil { + req.Header = tc.Header + } + res := httptest.NewRecorder() + if req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/json") + } + router.ServeHTTP(res, req) + assert.Equal(t, tc.WantStatus, res.Code, "status mismatch") + if tc.WantResponse != "" { + pattern := strings.Trim(tc.WantResponse, "*") + if pattern != tc.WantResponse { + assert.Contains(t, res.Body.String(), pattern, "response mismatch") + } else { + assert.JSONEq(t, tc.WantResponse, res.Body.String(), "response mismatch") + } + } + }) +} diff --git a/internal/test/db.go b/internal/test/db.go new file mode 100644 index 0000000..a1d79f3 --- /dev/null +++ b/internal/test/db.go @@ -0,0 +1,53 @@ +package test + +import ( + dbx "github.com/go-ozzo/ozzo-dbx" + _ "github.com/lib/pq" // initialize posgresql for test + "github.com/qiangxue/go-rest-api/internal/config" + "github.com/qiangxue/go-rest-api/pkg/dbcontext" + "github.com/qiangxue/go-rest-api/pkg/log" + "path" + "runtime" + "testing" +) + +var db *dbcontext.DB + +// DB returns the database connection for testing purpose. +func DB(t *testing.T) *dbcontext.DB { + if db != nil { + return db + } + logger, _ := log.NewForTest() + dir := getSourcePath() + cfg, err := config.Load(dir+"/../../config/local.yml", logger) + if err != nil { + t.Error(err) + t.FailNow() + } + dbc, err := dbx.MustOpen("postgres", cfg.DSN) + if err != nil { + t.Error(err) + t.FailNow() + } + dbc.LogFunc = logger.Infof + db = dbcontext.New(dbc) + return db +} + +// ResetTables truncates all data in the specified tables. +func ResetTables(t *testing.T, db *dbcontext.DB, tables ...string) { + for _, table := range tables { + _, err := db.DB().TruncateTable(table).Execute() + if err != nil { + t.Error(err) + t.FailNow() + } + } +} + +// getSourcePath returns the directory containing the source code that is calling this function. +func getSourcePath() string { + _, filename, _, _ := runtime.Caller(1) + return path.Dir(filename) +} diff --git a/internal/test/mock.go b/internal/test/mock.go new file mode 100644 index 0000000..c6058c9 --- /dev/null +++ b/internal/test/mock.go @@ -0,0 +1,35 @@ +package test + +import ( + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/go-ozzo/ozzo-routing/v2/content" + "github.com/go-ozzo/ozzo-routing/v2/cors" + "github.com/qiangxue/go-rest-api/internal/errors" + "github.com/qiangxue/go-rest-api/pkg/accesslog" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "net/http/httptest" +) + +// MockRoutingContext creates a routing.Conext for testing handlers. +func MockRoutingContext(req *http.Request) (*routing.Context, *httptest.ResponseRecorder) { + res := httptest.NewRecorder() + if req.Header.Get("Content-Type") == "" { + req.Header.Set("Content-Type", "application/json") + } + ctx := routing.NewContext(res, req) + ctx.SetDataWriter(&content.JSONDataWriter{}) + return ctx, res +} + +// MockRouter creates a routing.Router for testing APIs. +func MockRouter(logger log.Logger) *routing.Router { + router := routing.New() + router.Use( + accesslog.Handler(logger), + errors.Handler(logger), + content.TypeNegotiator(content.JSON), + cors.Handler(cors.AllowAll), + ) + return router +} diff --git a/migrations/20191217202658_init.down.sql b/migrations/20191217202658_init.down.sql new file mode 100644 index 0000000..de05f1b --- /dev/null +++ b/migrations/20191217202658_init.down.sql @@ -0,0 +1 @@ +DROP TABLE album; \ No newline at end of file diff --git a/migrations/20191217202658_init.up.sql b/migrations/20191217202658_init.up.sql new file mode 100644 index 0000000..1b89dd2 --- /dev/null +++ b/migrations/20191217202658_init.up.sql @@ -0,0 +1,7 @@ +CREATE TABLE album +( + id VARCHAR PRIMARY KEY, + name VARCHAR NOT NULL, + created_at TIMESTAMP NOT NULL, + updated_at TIMESTAMP NOT NULL +); diff --git a/pkg/accesslog/middleware.go b/pkg/accesslog/middleware.go new file mode 100644 index 0000000..f9e70c9 --- /dev/null +++ b/pkg/accesslog/middleware.go @@ -0,0 +1,34 @@ +// Package accesslog provides a middleware that records every RESTful API call in a log message. +package accesslog + +import ( + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/go-ozzo/ozzo-routing/v2/access" + "github.com/qiangxue/go-rest-api/pkg/log" + "net/http" + "time" +) + +// Handler returns a middleware that records an access log message for every HTTP request being processed. +func Handler(logger log.Logger) routing.Handler { + return func(c *routing.Context) error { + start := time.Now() + + rw := &access.LogResponseWriter{ResponseWriter: c.Response, Status: http.StatusOK} + c.Response = rw + + // associate request ID and session ID with the request context + // so that they can be added to the log messages + ctx := c.Request.Context() + ctx = log.WithRequest(ctx, c.Request) + c.Request = c.Request.WithContext(ctx) + + err := c.Next() + + // generate an access log message + logger.With(ctx, "duration", time.Now().Sub(start).Milliseconds(), "status", rw.Status). + Infof("%s %s %s %d %d", c.Request.Method, c.Request.URL.Path, c.Request.Proto, rw.Status, rw.BytesWritten) + + return err + } +} diff --git a/pkg/accesslog/middleware_test.go b/pkg/accesslog/middleware_test.go new file mode 100644 index 0000000..5787686 --- /dev/null +++ b/pkg/accesslog/middleware_test.go @@ -0,0 +1,24 @@ +package accesslog + +import ( + routing "github.com/go-ozzo/ozzo-routing/v2" + "github.com/qiangxue/go-rest-api/pkg/log" + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "testing" +) + +func TestHandler(t *testing.T) { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) + ctx := routing.NewContext(res, req) + + logger, entries := log.NewForTest() + handler := Handler(logger) + err := handler(ctx) + + assert.Nil(t, err) + assert.Equal(t, 1, entries.Len()) + assert.Equal(t, "GET /users HTTP/1.1 200 0", entries.All()[0].Message) +} diff --git a/pkg/dbcontext/db.go b/pkg/dbcontext/db.go new file mode 100644 index 0000000..7d3f904 --- /dev/null +++ b/pkg/dbcontext/db.go @@ -0,0 +1,64 @@ +// Package dbcontext provides DB transaction support for transactions tha span method calls of multiple +// repositories and services. +package dbcontext + +import ( + "context" + + dbx "github.com/go-ozzo/ozzo-dbx" + routing "github.com/go-ozzo/ozzo-routing/v2" +) + +// DB represents a DB connection that can be used to run SQL queries. +type DB struct { + db *dbx.DB +} + +// TransactionFunc represents a function that will start a transaction and run the given function. +type TransactionFunc func(ctx context.Context, f func(ctx context.Context) error) error + +type contextKey int + +const ( + txKey contextKey = iota +) + +// New returns a new DB connection that wraps the given dbx.DB instance. +func New(db *dbx.DB) *DB { + return &DB{db} +} + +// DB returns the dbx.DB wrapped by this object. +func (db *DB) DB() *dbx.DB { + return db.db +} + +// With returns a Builder that can be used to build and execute SQL queries. +// With will return the transaction if it is found in the given context. +// Otherwise it will return a DB connection associated with the context. +func (db *DB) With(ctx context.Context) dbx.Builder { + if tx, ok := ctx.Value(txKey).(*dbx.Tx); ok { + return tx + } + return db.db.WithContext(ctx) +} + +// Transactional starts a transaction and calls the given function with a context storing the transaction. +// The transaction associated with the context can be accesse via With(). +func (db *DB) Transactional(ctx context.Context, f func(ctx context.Context) error) error { + return db.db.TransactionalContext(ctx, nil, func(tx *dbx.Tx) error { + return f(context.WithValue(ctx, txKey, tx)) + }) +} + +// TransactionHandler returns a middleware that starts a transaction. +// The transaction started is kept in the context and can be accessed via With(). +func (db *DB) TransactionHandler() routing.Handler { + return func(c *routing.Context) error { + return db.db.TransactionalContext(c.Request.Context(), nil, func(tx *dbx.Tx) error { + ctx := context.WithValue(c.Request.Context(), txKey, tx) + c.Request = c.Request.WithContext(ctx) + return c.Next() + }) + } +} diff --git a/pkg/dbcontext/db_test.go b/pkg/dbcontext/db_test.go new file mode 100644 index 0000000..0d92688 --- /dev/null +++ b/pkg/dbcontext/db_test.go @@ -0,0 +1,141 @@ +package dbcontext + +import ( + "context" + "database/sql" + dbx "github.com/go-ozzo/ozzo-dbx" + routing "github.com/go-ozzo/ozzo-routing/v2" + _ "github.com/lib/pq" // initialize posgresql for test + "github.com/stretchr/testify/assert" + "net/http" + "net/http/httptest" + "os" + "testing" +) + +const DSN = "postgres://127.0.0.1/go_restful?sslmode=disable&user=postgres&password=postgres" + +func TestNew(t *testing.T) { + runDBTest(t, func(db *dbx.DB) { + dbc := New(db) + assert.NotNil(t, dbc) + assert.Equal(t, db, dbc.DB()) + }) +} + +func TestDB_Transactional(t *testing.T) { + runDBTest(t, func(db *dbx.DB) { + assert.Zero(t, runCountQuery(t, db)) + dbc := New(db) + + // successful transaction + err := dbc.Transactional(context.Background(), func(ctx context.Context) error { + _, err := dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "1", "name": "name1"}).Execute() + assert.Nil(t, err) + _, err = dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "2", "name": "name2"}).Execute() + assert.Nil(t, err) + return nil + }) + assert.Nil(t, err) + assert.Equal(t, 2, runCountQuery(t, db)) + + // failed transaction + err = dbc.Transactional(context.Background(), func(ctx context.Context) error { + _, err := dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "3", "name": "name1"}).Execute() + assert.Nil(t, err) + _, err = dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "4", "name": "name2"}).Execute() + assert.Nil(t, err) + return sql.ErrNoRows + }) + assert.Equal(t, sql.ErrNoRows, err) + assert.Equal(t, 2, runCountQuery(t, db)) + + // failed transaction, but queries made outside of the transaction + err = dbc.Transactional(context.Background(), func(ctx context.Context) error { + _, err := dbc.With(context.Background()).Insert("dbcontexttest", dbx.Params{"id": "3", "name": "name1"}).Execute() + assert.Nil(t, err) + _, err = dbc.With(context.Background()).Insert("dbcontexttest", dbx.Params{"id": "4", "name": "name2"}).Execute() + assert.Nil(t, err) + return sql.ErrNoRows + }) + assert.Equal(t, sql.ErrNoRows, err) + assert.Equal(t, 4, runCountQuery(t, db)) + }) +} + +func TestDB_TransactionHandler(t *testing.T) { + runDBTest(t, func(db *dbx.DB) { + assert.Zero(t, runCountQuery(t, db)) + dbc := New(db) + txHandler := dbc.TransactionHandler() + + // successful transaction + { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) + err := routing.NewContext(res, req, txHandler, func(c *routing.Context) error { + ctx := c.Request.Context() + _, err := dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "1", "name": "name1"}).Execute() + assert.Nil(t, err) + _, err = dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "2", "name": "name2"}).Execute() + assert.Nil(t, err) + return nil + }).Next() + assert.Nil(t, err) + assert.Equal(t, 2, runCountQuery(t, db)) + } + + // failed transaction + { + res := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "http://127.0.0.1/users", nil) + err := routing.NewContext(res, req, txHandler, func(c *routing.Context) error { + ctx := c.Request.Context() + _, err := dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "3", "name": "name1"}).Execute() + assert.Nil(t, err) + _, err = dbc.With(ctx).Insert("dbcontexttest", dbx.Params{"id": "4", "name": "name2"}).Execute() + assert.Nil(t, err) + return sql.ErrNoRows + }).Next() + assert.Equal(t, err, sql.ErrNoRows) + assert.Equal(t, 2, runCountQuery(t, db)) + } + }) +} + +func runDBTest(t *testing.T, f func(db *dbx.DB)) { + dsn, ok := os.LookupEnv("APP_DSN") + if !ok { + dsn = DSN + } + db, err := dbx.MustOpen("postgres", dsn) + if err != nil { + t.Error(err) + t.FailNow() + } + defer func() { + _ = db.Close() + }() + + sqls := []string{ + "CREATE TABLE IF NOT EXISTS dbcontexttest (id VARCHAR PRIMARY KEY, name VARCHAR)", + "TRUNCATE dbcontexttest", + } + for _, s := range sqls { + _, err = db.NewQuery(s).Execute() + if err != nil { + t.Error(err, " with SQL: ", s) + t.FailNow() + } + } + + f(db) +} + +func runCountQuery(t *testing.T, db *dbx.DB) int { + var count int + err := db.NewQuery("SELECT COUNT(*) FROM dbcontexttest").Row(&count) + assert.Nil(t, err) + return count + +} diff --git a/pkg/log/logger.go b/pkg/log/logger.go new file mode 100644 index 0000000..be6c3ee --- /dev/null +++ b/pkg/log/logger.go @@ -0,0 +1,104 @@ +// Package log provides context-aware and structured logging capabilities. +package log + +import ( + "context" + "github.com/google/uuid" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + "net/http" +) + +// Logger is a logger that supports log levels, context and structured logging. +type Logger interface { + // With returns a logger based off the root logger and decorates it with the given context and arguments. + With(ctx context.Context, args ...interface{}) Logger + + // Debug uses fmt.Sprint to construct and log a message at DEBUG level + Debug(args ...interface{}) + // Info uses fmt.Sprint to construct and log a message at INFO level + Info(args ...interface{}) + // Error uses fmt.Sprint to construct and log a message at ERROR level + Error(args ...interface{}) + + // Debugf uses fmt.Sprintf to construct and log a message at DEBUG level + Debugf(format string, args ...interface{}) + // Infof uses fmt.Sprintf to construct and log a message at INFO level + Infof(format string, args ...interface{}) + // Errorf uses fmt.Sprintf to construct and log a message at ERROR level + Errorf(format string, args ...interface{}) +} + +type logger struct { + *zap.SugaredLogger +} + +type contextKey int + +const ( + requestIDKey contextKey = iota + correlationIDKey +) + +// New creates a new logger using the default configuration. +func New() Logger { + l, _ := zap.NewProduction() + return NewWithZap(l) +} + +// NewWithZap creates a new logger using the preconfigured zap logger. +func NewWithZap(l *zap.Logger) Logger { + return &logger{l.Sugar()} +} + +// NewForTest returns a new logger and the corresponding observed logs which can be used in unit tests to verify log entries. +func NewForTest() (Logger, *observer.ObservedLogs) { + core, recorded := observer.New(zapcore.InfoLevel) + return NewWithZap(zap.New(core)), recorded +} + +// With returns a logger based off the root logger and decorates it with the given context and arguments. +// +// If the context contains request ID and/or correlation ID information (recorded via WithRequestID() +// and WithCorrelationID()), they will be added to every log message generated by the new logger. +// +// The arguments should be specified as a sequence of name, value pairs with names being strings. +// The arguments will also be added to every log message generated by the logger. +func (l *logger) With(ctx context.Context, args ...interface{}) Logger { + if ctx != nil { + if id, ok := ctx.Value(requestIDKey).(string); ok { + args = append(args, zap.String("request_id", id)) + } + if id, ok := ctx.Value(correlationIDKey).(string); ok { + args = append(args, zap.String("correlation_id", id)) + } + } + if len(args) > 0 { + return &logger{l.SugaredLogger.With(args...)} + } + return l +} + +// WithRequest returns a context which knows the request ID and correlation ID in the given request. +func WithRequest(ctx context.Context, req *http.Request) context.Context { + id := getRequestID(req) + if id == "" { + id = uuid.New().String() + } + ctx = context.WithValue(ctx, requestIDKey, id) + if id := getCorrelationID(req); id != "" { + ctx = context.WithValue(ctx, correlationIDKey, id) + } + return ctx +} + +// getCorrelationID extracts the correlation ID from the HTTP request +func getCorrelationID(req *http.Request) string { + return req.Header.Get("X-Correlation-ID") +} + +// getRequestID extracts the correlation ID from the HTTP request +func getRequestID(req *http.Request) string { + return req.Header.Get("X-Request-ID") +} diff --git a/pkg/log/logger_test.go b/pkg/log/logger_test.go new file mode 100644 index 0000000..0f8787b --- /dev/null +++ b/pkg/log/logger_test.go @@ -0,0 +1,83 @@ +package log + +import ( + "bytes" + "context" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "net/http" + "reflect" + "testing" +) + +func TestNew(t *testing.T) { + assert.NotNil(t, New()) +} + +func TestNewWithZap(t *testing.T) { + zl, _ := zap.NewProduction() + l := NewWithZap(zl) + assert.NotNil(t, l) +} + +func TestWithRequest(t *testing.T) { + req := buildRequest("abc", "123") + ctx := WithRequest(context.Background(), req) + assert.Equal(t, "abc", ctx.Value(requestIDKey).(string)) + assert.Equal(t, "123", ctx.Value(correlationIDKey).(string)) + + req = buildRequest("", "123") + ctx = WithRequest(context.Background(), req) + assert.NotEmpty(t, ctx.Value(requestIDKey).(string)) + assert.Equal(t, "123", ctx.Value(correlationIDKey).(string)) +} + +func Test_getCorrelationID(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", bytes.NewBufferString("")) + assert.Empty(t, getCorrelationID(req)) + req.Header.Set("X-Correlation-ID", "test") + assert.Equal(t, "test", getCorrelationID(req)) +} + +func Test_getRequestID(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com", bytes.NewBufferString("")) + assert.Empty(t, getRequestID(req)) + req.Header.Set("X-Request-ID", "test") + assert.Equal(t, "test", getRequestID(req)) +} + +func Test_logger_With(t *testing.T) { + l := New() + l2 := l.With(nil) + assert.True(t, reflect.DeepEqual(l2, l)) + + req := buildRequest("abc", "123") + ctx := WithRequest(context.Background(), req) + l3 := l.With(ctx) + assert.False(t, reflect.DeepEqual(l3, l2)) +} + +func buildRequest(requestID, correlationID string) *http.Request { + req, _ := http.NewRequest("GET", "http://example.com", bytes.NewBufferString("")) + if requestID != "" { + req.Header.Set("X-Request-ID", requestID) + } + if correlationID != "" { + req.Header.Set("X-Correlation-ID", correlationID) + } + return req +} + +func TestNewForTest(t *testing.T) { + logger, entries := NewForTest() + assert.Equal(t, 0, entries.Len()) + logger.Info("msg 1") + assert.Equal(t, 1, entries.Len()) + logger.Info("msg 2") + logger.Info("msg 3") + assert.Equal(t, 3, entries.Len()) + entries.TakeAll() + assert.Equal(t, 0, entries.Len()) + logger.Info("msg 4") + assert.Equal(t, 1, entries.Len()) +} diff --git a/pkg/pagination/pages.go b/pkg/pagination/pages.go new file mode 100644 index 0000000..97f6bca --- /dev/null +++ b/pkg/pagination/pages.go @@ -0,0 +1,146 @@ +// Package pagination provides support for pagination requests and responses. +package pagination + +import ( + "fmt" + "net/http" + "strconv" + "strings" +) + +var ( + // DefaultPageSize specifies the default page size + DefaultPageSize = 100 + // MaxPageSize specifies the maximum page size + MaxPageSize = 1000 + // PageVar specifies the query parameter name for page number + PageVar = "page" + // PageSizeVar specifies the query parameter name for page size + PageSizeVar = "per_page" +) + +// Pages represents a paginated list of data items. +type Pages struct { + Page int `json:"page"` + PerPage int `json:"per_page"` + PageCount int `json:"page_count"` + TotalCount int `json:"total_count"` + Items interface{} `json:"items"` +} + +// New creates a new Pages instance. +// The page parameter is 1-based and refers to the current page index/number. +// The perPage parameter refers to the number of items on each page. +// And the total parameter specifies the total number of data items. +// If total is less than 0, it means total is unknown. +func New(page, perPage, total int) *Pages { + if perPage <= 0 { + perPage = DefaultPageSize + } + if perPage > MaxPageSize { + perPage = MaxPageSize + } + pageCount := -1 + if total >= 0 { + pageCount = (total + perPage - 1) / perPage + if page > pageCount { + page = pageCount + } + } + if page < 1 { + page = 1 + } + + return &Pages{ + Page: page, + PerPage: perPage, + TotalCount: total, + PageCount: pageCount, + } +} + +// NewFromRequest creates a Pages object using the query parameters found in the given HTTP request. +// count stands for the total number of items. Use -1 if this is unknown. +func NewFromRequest(req *http.Request, count int) *Pages { + page := parseInt(req.URL.Query().Get(PageVar), 1) + perPage := parseInt(req.URL.Query().Get(PageSizeVar), DefaultPageSize) + return New(page, perPage, count) +} + +// parseInt parses a string into an integer. If parsing is failed, defaultValue will be returned. +func parseInt(value string, defaultValue int) int { + if value == "" { + return defaultValue + } + if result, err := strconv.Atoi(value); err == nil { + return result + } + return defaultValue +} + +// Offset returns the OFFSET value that can be used in a SQL statement. +func (p *Pages) Offset() int { + return (p.Page - 1) * p.PerPage +} + +// Limit returns the LIMIT value that can be used in a SQL statement. +func (p *Pages) Limit() int { + return p.PerPage +} + +// BuildLinkHeader returns an HTTP header containing the links about the pagination. +func (p *Pages) BuildLinkHeader(baseURL string, defaultPerPage int) string { + links := p.BuildLinks(baseURL, defaultPerPage) + header := "" + if links[0] != "" { + header += fmt.Sprintf("<%v>; rel=\"first\", ", links[0]) + header += fmt.Sprintf("<%v>; rel=\"prev\"", links[1]) + } + if links[2] != "" { + if header != "" { + header += ", " + } + header += fmt.Sprintf("<%v>; rel=\"next\"", links[2]) + if links[3] != "" { + header += fmt.Sprintf(", <%v>; rel=\"last\"", links[3]) + } + } + return header +} + +// BuildLinks returns the first, prev, next, and last links corresponding to the pagination. +// A link could be an empty string if it is not needed. +// For example, if the pagination is at the first page, then both first and prev links +// will be empty. +func (p *Pages) BuildLinks(baseURL string, defaultPerPage int) [4]string { + var links [4]string + pageCount := p.PageCount + page := p.Page + if pageCount >= 0 && page > pageCount { + page = pageCount + } + if strings.Contains(baseURL, "?") { + baseURL += "&" + } else { + baseURL += "?" + } + if page > 1 { + links[0] = fmt.Sprintf("%v%v=%v", baseURL, PageVar, 1) + links[1] = fmt.Sprintf("%v%v=%v", baseURL, PageVar, page-1) + } + if pageCount >= 0 && page < pageCount { + links[2] = fmt.Sprintf("%v%v=%v", baseURL, PageVar, page+1) + links[3] = fmt.Sprintf("%v%v=%v", baseURL, PageVar, pageCount) + } else if pageCount < 0 { + links[2] = fmt.Sprintf("%v%v=%v", baseURL, PageVar, page+1) + } + if perPage := p.PerPage; perPage != defaultPerPage { + for i := 0; i < 4; i++ { + if links[i] != "" { + links[i] += fmt.Sprintf("&%v=%v", PageSizeVar, perPage) + } + } + } + + return links +} diff --git a/pkg/pagination/pages_test.go b/pkg/pagination/pages_test.go new file mode 100644 index 0000000..a1aa027 --- /dev/null +++ b/pkg/pagination/pages_test.go @@ -0,0 +1,102 @@ +package pagination + +import ( + "bytes" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNew(t *testing.T) { + tests := []struct { + tag string + page, perPage, total int + expectedPage, expectedPerPage, expectedTotal, pageCount, offset, limit int + }{ + // varying page + {"t1", 1, 20, 50, 1, 20, 50, 3, 0, 20}, + {"t2", 2, 20, 50, 2, 20, 50, 3, 20, 20}, + {"t3", 3, 20, 50, 3, 20, 50, 3, 40, 20}, + {"t4", 4, 20, 50, 3, 20, 50, 3, 40, 20}, + {"t5", 0, 20, 50, 1, 20, 50, 3, 0, 20}, + + // varying perPage + {"t6", 1, 0, 50, 1, 100, 50, 1, 0, 100}, + {"t7", 1, -1, 50, 1, 100, 50, 1, 0, 100}, + {"t8", 1, 100, 50, 1, 100, 50, 1, 0, 100}, + {"t9", 1, 1001, 50, 1, 1000, 50, 1, 0, 1000}, + + // varying total + {"t10", 1, 20, 0, 1, 20, 0, 0, 0, 20}, + {"t11", 1, 20, -1, 1, 20, -1, -1, 0, 20}, + } + + for _, test := range tests { + p := New(test.page, test.perPage, test.total) + assert.Equal(t, test.expectedPage, p.Page, test.tag) + assert.Equal(t, test.expectedPerPage, p.PerPage, test.tag) + assert.Equal(t, test.expectedTotal, p.TotalCount, test.tag) + assert.Equal(t, test.pageCount, p.PageCount, test.tag) + assert.Equal(t, test.offset, p.Offset(), test.tag) + assert.Equal(t, test.limit, p.Limit(), test.tag) + } +} + +func TestPages_BuildLinkHeader(t *testing.T) { + baseURL := "/tokens" + defaultPerPage := 10 + tests := []struct { + tag string + page, perPage, total int + header string + }{ + {"t1", 1, 20, 50, "; rel=\"next\", ; rel=\"last\""}, + {"t2", 2, 20, 50, "; rel=\"first\", ; rel=\"prev\", ; rel=\"next\", ; rel=\"last\""}, + {"t3", 3, 20, 50, "; rel=\"first\", ; rel=\"prev\""}, + {"t4", 0, 20, 50, "; rel=\"next\", ; rel=\"last\""}, + {"t5", 4, 20, 50, "; rel=\"first\", ; rel=\"prev\""}, + {"t6", 1, 20, 0, ""}, + {"t7", 4, 20, -1, "; rel=\"first\", ; rel=\"prev\", ; rel=\"next\""}, + } + for _, test := range tests { + p := New(test.page, test.perPage, test.total) + assert.Equal(t, test.header, p.BuildLinkHeader(baseURL, defaultPerPage), test.tag) + } + + baseURL = "/tokens?from=10" + p := New(1, 20, 50) + assert.Equal(t, "; rel=\"next\", ; rel=\"last\"", p.BuildLinkHeader(baseURL, defaultPerPage)) +} + +func Test_parseInt(t *testing.T) { + type args struct { + value string + defaultValue int + } + tests := []struct { + name string + args args + want int + }{ + {"t1", args{"123", 100}, 123}, + {"t2", args{"", 100}, 100}, + {"t3", args{"a", 100}, 100}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseInt(tt.args.value, tt.args.defaultValue); got != tt.want { + t.Errorf("parseInt() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewFromRequest(t *testing.T) { + req, _ := http.NewRequest("GET", "http://example.com?page=2&per_page=20", bytes.NewBufferString("")) + p := NewFromRequest(req, 100) + assert.Equal(t, 2, p.Page) + assert.Equal(t, 20, p.PerPage) + assert.Equal(t, 100, p.TotalCount) + assert.Equal(t, 5, p.PageCount) +} diff --git a/testdata/testdata.sql b/testdata/testdata.sql new file mode 100644 index 0000000..14fe584 --- /dev/null +++ b/testdata/testdata.sql @@ -0,0 +1,6 @@ +INSERT INTO album (id, name, created_at, updated_at) +VALUES ('967d5bb5-3a7a-4d5e-8a6c-febc8c5b3f13', 'Hollywood''s Bleeding', '2019-10-01 15:36:38'::timestamp, '2019-10-01 15:36:38'::timestamp), + ('c809bf15-bc2c-4621-bb96-70af96fd5d67', 'AI YoungBoy 2', '2019-10-02 11:16:12'::timestamp, '2019-10-02 11:16:12'::timestamp), + ('2367710a-d4fb-49f5-8860-557b337386dd', 'KIRK', '2019-10-05 05:21:11'::timestamp, '2019-10-05 05:21:11'::timestamp), + ('b0a24f12-428f-4ff5-84d5-bc1fdcff6f03', 'Lover', '2019-10-11 19:43:18'::timestamp, '2019-10-11 19:43:18'::timestamp), + ('e0bb80ec-75a6-4348-bfc3-6ac1e89b195e', 'So Much Fun', '2019-10-12 12:16:02'::timestamp, '2019-10-12 12:16:02'::timestamp);